diff --git a/llama/patches/0029-vulkan-Call-ggml_vk_buffer_write_2d-from-ggml_vk_buf.patch b/llama/patches/0029-vulkan-Call-ggml_vk_buffer_write_2d-from-ggml_vk_buf.patch new file mode 100644 index 000000000..e9737aa41 --- /dev/null +++ b/llama/patches/0029-vulkan-Call-ggml_vk_buffer_write_2d-from-ggml_vk_buf.patch @@ -0,0 +1,32 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jeff Bolz +Date: Wed, 29 Oct 2025 03:53:04 -0500 +Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy + (#16793) + +This lets the copy to the destination device use the host-visible +vidmem optimization. +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +---- + 1 file changed, 1 insertion(+), 4 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 221e29509..18b7cbccf 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr + VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); + // Copy device to device + ggml_vk_ensure_sync_staging_buffer(src->device, size); +- ggml_vk_ensure_sync_staging_buffer(dst->device, size); + + // Copy to src staging buffer + ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); +- // memcpy to dst staging buffer +- memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); + // Copy to dst buffer +- ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); ++ ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1); + } + } + diff --git a/llama/patches/0030-Vulkan-MMQ-Integer-Dot-Refactor-and-K-Quant-support-.patch b/llama/patches/0030-Vulkan-MMQ-Integer-Dot-Refactor-and-K-Quant-support-.patch new file mode 100644 index 000000000..1b1f65e42 --- /dev/null +++ b/llama/patches/0030-Vulkan-MMQ-Integer-Dot-Refactor-and-K-Quant-support-.patch @@ -0,0 +1,2140 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Ruben Ortlam +Date: Wed, 29 Oct 2025 14:39:03 +0100 +Subject: [PATCH] Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536) + +* vulkan: add mmq q2_k integer dot support + +* Refactor mmq caching + +* Reduce mmq register use + +* Load 4 quant blocks into shared memory in one step + +* Pack q2_k blocks into caches of 32 + +* Use 32-bit accumulators for integer dot matmul + +* Add q4_k mmq + +* Add q3_k mmq + +* Add q5_k mmq + +* Add q6_k mmq + +* Add mxfp4 mmq, enable MMQ MUL_MAT_ID + +* Fix mmv dm loads +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 165 +++++- + .../vulkan-shaders/dequant_funcs.glsl | 10 +- + .../vulkan-shaders/dequant_funcs_cm2.glsl | 6 +- + .../vulkan-shaders/dequant_mxfp4.comp | 4 +- + .../vulkan-shaders/dequant_q2_k.comp | 4 +- + .../vulkan-shaders/dequant_q4_k.comp | 4 +- + .../vulkan-shaders/dequant_q5_k.comp | 4 +- + .../vulkan-shaders/mul_mat_vec_q2_k.comp | 6 +- + .../vulkan-shaders/mul_mat_vec_q4_k.comp | 6 +- + .../vulkan-shaders/mul_mat_vec_q5_k.comp | 6 +- + .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 72 +-- + .../vulkan-shaders/mul_mm_funcs.glsl | 14 +- + .../vulkan-shaders/mul_mm_id_funcs.glsl | 70 +++ + .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 288 +++------- + .../vulkan-shaders/mul_mmq_funcs.glsl | 538 ++++++++++++++++-- + .../vulkan-shaders/mul_mmq_shmem_types.glsl | 78 +++ + .../src/ggml-vulkan/vulkan-shaders/types.glsl | 53 +- + 
.../vulkan-shaders/vulkan-shaders-gen.cpp | 5 +- + 18 files changed, 928 insertions(+), 405 deletions(-) + create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl + create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 18b7cbccf..53b57c179 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -488,6 +488,7 @@ struct vk_device_struct { + vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; ++ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_matmul_split_k_reduce; + vk_pipeline pipeline_quantize_q8_1; +@@ -2449,8 +2450,11 @@ static void ggml_vk_load_shaders(vk_device& device) { + l_warptile_id, m_warptile_id, s_warptile_id, + l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, ++ l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k, + l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, +- l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; ++ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid, ++ l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int, ++ l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k; + std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, + l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, + l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, +@@ -2513,10 +2517,16 @@ static void ggml_vk_load_shaders(vk_device& device) { + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + ++ // Integer MMQ has a smaller shared memory profile, but heavier register use + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + ++ // K-quants use even more registers, mitigate by setting WMITER to 1 ++ l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; ++ m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; ++ s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 }; ++ + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; + m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; + s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; +@@ -2525,10 +2535,18 @@ static void ggml_vk_load_shaders(vk_device& device) { + m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; + s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + ++ l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; ++ m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; ++ s_warptile_mmqid_int = { 
mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; ++ ++ l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; ++ m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; ++ s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; ++ + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; +- m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; ++ m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } + + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; +@@ -2913,18 +2931,15 @@ static void ggml_vk_load_shaders(vk_device& device) { + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + +-#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) { \ +- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ +- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + } \ + if (device->mul_mat ## ID ## _m[TYPE]) { \ +- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ +- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + } \ + if (device->mul_mat ## ID ## _s[TYPE]) { \ +- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ +- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## 
WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + } \ + + // Create 2 variants, {f16,f32} accumulator +@@ -2963,11 +2978,19 @@ static void ggml_vk_load_shaders(vk_device& device) { + + #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { +- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); +- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); +- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); +- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); +- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); ++ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); ++ ++ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); ++ ++ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); ++ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + } + #endif + +@@ -2997,6 +3020,24 @@ static void ggml_vk_load_shaders(vk_device& device) { + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_MXFP4, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); ++ ++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) ++ if (device->integer_dot_product) { ++ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); ++ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); ++ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); ++ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); ++ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); ++ ++ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); ++ ++ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); ++ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); ++ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); ++ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); ++ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); ++ } ++#endif + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); +@@ -3023,6 +3064,24 @@ static void ggml_vk_load_shaders(vk_device& device) { + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); ++ ++#if 
defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) ++ if (device->integer_dot_product) { ++ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); ++ ++ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); ++ ++ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); ++ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); ++ } ++#endif + } + #undef CREATE_MM2 + #undef CREATE_MMQ +@@ -3087,6 +3146,12 @@ static void ggml_vk_load_shaders(vk_device& device) { + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); ++ ++ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); ++ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); ++ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); ++ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); ++ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, 
matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + } + #endif + +@@ -3146,7 +3211,7 @@ static void ggml_vk_load_shaders(vk_device& device) { + } + // reusing CREATE_MM from the fp32 path + if ((device->coopmat2 || device->coopmat_support) +-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) ++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + && !device->coopmat_bf16_support + #endif + ) { +@@ -4930,7 +4995,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte + + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { +- vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; ++ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; +@@ -5077,6 +5142,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co + } + } + ++ // MMQ ++ if (src1_type == GGML_TYPE_Q8_1) { ++ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc; ++ ++ if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { ++ return nullptr; ++ } ++ ++ return pipelines; ++ } ++ + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); + + switch (src0_type) { +@@ -6879,10 +6955,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + +- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); ++ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; ++ ++ // Check for mmq first ++ vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; ++ ++ if (mmp == nullptr) { ++ // Fall back to f16 dequant mul mat ++ mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); ++ quantize_y = false; ++ } + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; +- const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; ++ const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat +@@ -6892,8 +6977,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + +- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); +- const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; ++ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); ++ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? 
f16_type : src0->type); + +@@ -6906,12 +6991,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; +- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; ++ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; ++ vk_pipeline to_q8_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); +@@ -6926,9 +7012,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + ++ if (quantize_y) { ++ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); ++ } ++ + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; +- const uint64_t y_sz_upd = y_sz * ne12 * ne13; ++ uint64_t y_sz_upd = y_sz * ne12 * ne13; ++ if (quantize_y) { ++ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; ++ } + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { +@@ -6937,7 +7030,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } +- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { ++ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + +@@ -6949,6 +7042,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } ++ if (quantize_y) { ++ ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); ++ } + return; + } + +@@ -6985,6 +7081,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); ++ } else if (quantize_y) { ++ d_Y = ctx->prealloc_y; ++ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; +@@ -7016,6 +7115,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + ctx->prealloc_y_last_tensor_used = src1; + } + } ++ if (quantize_y) { ++ if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || ++ ctx->prealloc_y_last_tensor_used != src1) { ++ if (ctx->prealloc_y_need_sync) { ++ ggml_vk_sync_buffers(ctx, subctx); ++ } ++ ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); ++ ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ++ ctx->prealloc_y_last_tensor_used = src1; ++ } ++ } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; +@@ -7024,14 +7134,19 @@ static void 
ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + +- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { ++ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + ++ uint32_t y_sz_total = y_sz * ne12 * ne13; ++ if (quantize_y) { ++ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; ++ } ++ + // compute + ggml_vk_matmul_id( + ctx, subctx, pipeline, +- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, ++ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, + { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, + ne01, ne21, ne10, ne10, ne10, ne01, + stride_batch_x, stride_batch_y, ne20*ne21, +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +index 0d98f5a9d..09676a623 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + #if defined(DATA_A_MXFP4) + vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); +- return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); ++ return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; + } + vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + vec2 v0 = dequantize(ib, iqs, a_offset); +@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { + + const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); + const uint scales = data_a[a_offset + ib].scales[scalesi]; +- const vec2 d = vec2(data_a[a_offset + ib].d); ++ const vec2 dm = vec2(data_a[a_offset + ib].dm); + +- return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); ++ return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); + } + vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + +- const vec2 loadd = vec2(data_a[a_offset + ib].d); ++ const vec2 loadd = vec2(data_a[a_offset + ib].dm); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); +@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + +- const vec2 loadd = vec2(data_a[a_offset + ib].d); ++ const vec2 loadd = vec2(data_a[a_offset + ib].dm); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? 
is : (is - 4); +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +index 67baedf7c..8ac6482dc 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) + { + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); +- const f16vec2 d = bl.block.d; ++ const f16vec2 dm = bl.block.dm; + const uint idx = coordInBlock[1]; + + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 +@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 + qs = unpack8(qs)[idx & 1]; + + const uint scales = bl.block.scales[scalesi]; +- float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); ++ float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4); + return ret; + } + +@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; +- float16_t ret = float16_t(kvalues_mxfp4[qs] * d); ++ float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); + return ret; + } + #endif +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +index ffba5a77d..3194ba291 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +@@ -26,7 +26,7 @@ void main() { + const float d = e8m0_to_fp32(data_a[ib].e); + + [[unroll]] for (uint l = 0; l < 8; ++l) { +- data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); +- data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); ++ data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); ++ data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); + } + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +index 58dc2e5df..dc05a7834 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +@@ -24,8 +24,8 @@ void main() { + const uint ql_idx = 32 * ip + il; + const uint8_t qs = data_a[i].qs[32 * ip + il]; + +- FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); +- FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); ++ FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x); ++ FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y); + data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); + data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); + data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +index 8b7be557e..0f23dc0a3 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp ++++ 
b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +@@ -20,8 +20,8 @@ void main() { + const uint is = 2 * il; + const uint n = 4; + +- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); +- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); ++ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); ++ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); + + const uint y_idx = ib * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +index 6bc04670f..970469a60 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +@@ -19,8 +19,8 @@ void main() { + const uint ir = tid % 16; + const uint is = 2 * il; + +- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); +- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); ++ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); ++ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); + + const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +index 03ed25d3b..14093c0de 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + +- vec2 d = vec2(data_a[ib0 + i].d); +- const FLOAT_TYPE dall = FLOAT_TYPE(d.x); +- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); ++ const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); +@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, + fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); + } +- temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); ++ temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n])); + } + } + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +index 21d07d2e5..49d91ad59 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; +- vec2 d = vec2(data_a[ib0 + i].d); +- const FLOAT_TYPE dall = FLOAT_TYPE(d.x); +- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); ++ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; +@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, 
+ fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); +- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); ++ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n])); + } + } + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +index 9e46c89a1..0d61b4966 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; +- vec2 d = vec2(data_a[ib0 + i].d); +- const FLOAT_TYPE dall = FLOAT_TYPE(d.x); +- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); ++ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; +@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); +- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); ++ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n])); + } + } + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +index a20788c4b..d260969f0 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; + + #define NUM_WARPS (BLOCK_SIZE / WARP) + +-#ifdef MUL_MAT_ID +-shared u16vec2 row_ids[BN]; +-uint _ne1; +- +-#ifdef MUL_MAT_ID_USE_SUBGROUPS +-shared uvec4 ballots_sh[NUM_WARPS]; +- +-void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { +- _ne1 = 0; +- uint num_elements = p.nei1 * p.nei0; +- uint nei0shift = findLSB(p.nei0); +- +- uint ids[16]; +- uint iter = 0; +- +- for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { +- // prefetch up to 16 elements +- if (iter == 0) { +- [[unroll]] for (uint k = 0; k < 16; ++k) { +- uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; +- bool in_range = i < num_elements; +- uint ii1; +- if (nei0_is_pow2) { +- ii1 = i >> nei0shift; +- } else { +- ii1 = i / p.nei0; +- } +- uint ii0 = i - ii1 * p.nei0; +- ids[k] = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; +- } +- } +- uint i = j + gl_LocalInvocationIndex; +- bool in_range = i < num_elements; +- uint ii1; +- if (nei0_is_pow2) { +- ii1 = i >> nei0shift; +- } else { +- ii1 = i / p.nei0; +- } +- uint ii0 = i - ii1 * p.nei0; +- uint id = ids[iter++]; +- uvec4 ballot = subgroupBallot(in_range && id == expert_idx); +- +- ballots_sh[gl_SubgroupID] = ballot; +- barrier(); +- +- uint subgroup_base = 0; +- uint total = 0; +- for (uint k = 0; k < gl_NumSubgroups; ++k) { +- if (k == gl_SubgroupID) { +- subgroup_base = total; +- } +- total += subgroupBallotBitCount(ballots_sh[k]); +- } +- barrier(); +- +- uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); +- if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { +- row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); +- } +- _ne1 += total; +- iter &= 15; +- if (_ne1 >= (ic + 1) * BN) { +- break; +- } +- } +- barrier(); +-} +-#endif // MUL_MAT_ID_USE_SUBGROUPS +-#endif // MUL_MAT_ID +- + #ifdef COOPMAT + shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; + #endif + ++#include "mul_mm_id_funcs.glsl" + #include "mul_mm_funcs.glsl" + + void main() { +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +index 0ebfbd646..ee5ded2e8 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + +- const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 ++ const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + +- const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); ++ const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); + const uint scales = data_a[ib].scales[scalesi]; +- const vec2 d = vec2(data_a[ib].d); ++ const vec2 dm = vec2(data_a[ib].dm); + +- const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); ++ const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); + #elif defined(DATA_A_Q3_K) +@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + +- const vec2 loadd = vec2(data_a[ib].d); ++ const vec2 loadd = vec2(data_a[ib].dm); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); +@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + +- const vec2 loadd = vec2(data_a[ib].d); ++ const vec2 loadd = vec2(data_a[ib].dm); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? 
is : (is - 4); +@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin + const uint ib = idx / 8; + const uint iqs = (idx & 0x07) * 2; + +- const float d = e8m0_to_fp32(data_a[ib].e); ++ const float d = e8m0_to_fp32(data_a[ib].e) * 0.5; + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +new file mode 100644 +index 000000000..1d0e84ac9 +--- /dev/null ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +@@ -0,0 +1,70 @@ ++#ifdef MUL_MAT_ID ++shared u16vec2 row_ids[BN]; ++uint _ne1; ++ ++#ifdef MUL_MAT_ID_USE_SUBGROUPS ++shared uvec4 ballots_sh[NUM_WARPS]; ++ ++void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { ++ _ne1 = 0; ++ uint num_elements = p.nei1 * p.nei0; ++ uint nei0shift = findLSB(p.nei0); ++ ++ uint ids[16]; ++ uint iter = 0; ++ ++ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { ++ // prefetch up to 16 elements ++ if (iter == 0) { ++ [[unroll]] for (uint k = 0; k < 16; ++k) { ++ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; ++ bool in_range = i < num_elements; ++ uint ii1; ++ if (nei0_is_pow2) { ++ ii1 = i >> nei0shift; ++ } else { ++ ii1 = i / p.nei0; ++ } ++ uint ii0 = i - ii1 * p.nei0; ++ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; ++ } ++ } ++ uint i = j + gl_LocalInvocationIndex; ++ bool in_range = i < num_elements; ++ uint ii1; ++ if (nei0_is_pow2) { ++ ii1 = i >> nei0shift; ++ } else { ++ ii1 = i / p.nei0; ++ } ++ uint ii0 = i - ii1 * p.nei0; ++ uint id = ids[iter++]; ++ uvec4 ballot = subgroupBallot(in_range && id == expert_idx); ++ ++ ballots_sh[gl_SubgroupID] = ballot; ++ barrier(); ++ ++ uint subgroup_base = 0; ++ uint total = 0; ++ for (uint k = 0; k < gl_NumSubgroups; ++k) { ++ if (k == gl_SubgroupID) { ++ subgroup_base = total; ++ } ++ total += subgroupBallotBitCount(ballots_sh[k]); ++ } ++ barrier(); ++ ++ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); ++ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { ++ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); ++ } ++ _ne1 += total; ++ iter &= 15; ++ if (_ne1 >= (ic + 1) * BN) { ++ break; ++ } ++ } ++ barrier(); ++} ++#endif // MUL_MAT_ID_USE_SUBGROUPS ++#endif // MUL_MAT_ID +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +index b5d761c0b..8b238ac4b 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +@@ -10,10 +10,9 @@ + #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + #endif + +-#ifdef COOPMAT +-#extension GL_KHR_cooperative_matrix : enable +-#extension GL_KHR_memory_scope_semantics : enable ++#if defined(MUL_MAT_ID_USE_SUBGROUPS) + #extension GL_KHR_shader_subgroup_basic : enable ++#extension GL_KHR_shader_subgroup_ballot : enable + #endif + + #ifdef MUL_MAT_ID +@@ -24,7 +23,10 @@ + + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++#if defined(A_TYPE_PACKED16) ++layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; ++#endif + #if defined(A_TYPE_PACKED32) + layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; + #endif +@@ 
-76,40 +78,27 @@ layout (constant_id = 10) const uint WARP = 32; + + #define BK 32 + +-#ifdef COOPMAT +-#define SHMEM_STRIDE (BK / 4 + 4) +-#else +-#define SHMEM_STRIDE (BK / 4 + 1) +-#endif ++#define MMQ_SHMEM + +-shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; ++#include "mul_mmq_shmem_types.glsl" + +-#ifndef COOPMAT +-#if QUANT_AUXF == 1 +-shared FLOAT_TYPE buf_a_dm[BM]; +-#else +-shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; +-#endif ++#ifndef BK_STEP ++#define BK_STEP 4 + #endif + +-shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; +-#ifndef COOPMAT +-shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; +-#endif ++// Shared memory cache ++shared block_a_cache buf_a[BM * BK_STEP]; ++shared block_b_cache buf_b[BN * BK_STEP]; ++// Register cache ++block_a_cache cache_a[WMITER * TM]; ++block_b_cache cache_b; + +-#define LOAD_VEC_A (4 * QUANT_R) ++#define LOAD_VEC_A (4 * QUANT_R_MMQ) + #define LOAD_VEC_B 16 + +-#ifdef MUL_MAT_ID +-shared u16vec2 row_ids[4096]; +-#endif // MUL_MAT_ID +- + #define NUM_WARPS (BLOCK_SIZE / WARP) + +-#ifdef COOPMAT +-shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +-#endif +- ++#include "mul_mm_id_funcs.glsl" + #include "mul_mmq_funcs.glsl" + + void main() { +@@ -139,26 +128,12 @@ void main() { + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; +- +-#ifdef COOPMAT +- const uint warp_i = gl_SubgroupID; +- +- const uint tiw = gl_SubgroupInvocationID; +- +- const uint cms_per_row = WM / TM; +- const uint cms_per_col = WN / TN; +- +- const uint storestride = WARP / TM; +- const uint store_r = tiw % TM; +- const uint store_c = tiw / TM; +-#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +-#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); +@@ -172,17 +147,27 @@ void main() { + const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; + + #ifdef MUL_MAT_ID +- uint _ne1 = 0; +- for (uint ii1 = 0; ii1 < p.nei1; ii1++) { +- for (uint ii0 = 0; ii0 < p.nei0; ii0++) { ++#ifdef MUL_MAT_ID_USE_SUBGROUPS ++ if (bitCount(p.nei0) == 1) { ++ load_row_ids(expert_idx, true, ic); ++ } else { ++ load_row_ids(expert_idx, false, ic); ++ } ++#else ++ _ne1 = 0; ++ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { ++ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { +- row_ids[_ne1] = u16vec2(ii0, ii1); ++ if (_ne1 >= ic * BN) { ++ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); ++ } + _ne1++; + } + } + } + + barrier(); ++#endif + + // Workgroup has no work + if (ic * BN >= _ne1) return; +@@ -209,159 +194,70 @@ void main() { + uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; + #endif + +-#ifdef COOPMAT +- coopmat cache_a; +- coopmat cache_b; +- coopmat cm_result; +- +- coopmat factors[cms_per_row * cms_per_col]; +- +- coopmat sums[cms_per_row * cms_per_col]; +- +- [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { +- sums[i] = coopmat(0.0f); +- } +-#else +- int32_t cache_a_qs[WMITER * TM * BK / 4]; +- +- int32_t cache_b_qs[TN * BK / 4]; +- + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +-#endif + +-#if QUANT_AUXF == 1 +- FLOAT_TYPE cache_a_dm[WMITER * TM]; +-#else +- FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; +-#endif +- +- FLOAT_TYPE_VEC2 
cache_b_ds[TN]; +- +- for (uint block = start_k; block < end_k; block += BK) { ++ for (uint block = start_k; block < end_k; block += BK * BK_STEP) { + [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { +- const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; +- const uint iqs = loadr_a; + const uint buf_ib = loadc_a + l; ++ const uint ib = pos_a_ib + buf_ib * p.stride_a / BK; ++ const uint iqs = loadr_a; + +- if (iqs == 0) { +-#if QUANT_AUXF == 1 +- buf_a_dm[buf_ib] = get_d(ib); +-#else +- buf_a_dm[buf_ib] = get_dm(ib); +-#endif ++ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { ++ block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs); + } +-#if QUANT_R == 1 +- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); +-#else +- const i32vec2 vals = repack(ib, iqs); +- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; +- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; +-#endif + } + [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { ++ const uint buf_ib = loadc_b + l; ++ + #ifdef MUL_MAT_ID +- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; +- const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +- const uint ib = idx / 8; +- const uint iqs = idx & 0x7; ++ const u16vec2 row_idx = row_ids[buf_ib]; ++ const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK; + #else +- const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; +- const uint ib_outer = ib / 4; +- const uint ib_inner = ib % 4; +- +- const uint iqs = loadr_b; ++ const uint ib = pos_b_ib + buf_ib * p.stride_b / BK; + #endif ++ const uint iqs = loadr_b; + +- const uint buf_ib = loadc_b + l; +- +- if (iqs == 0) { +- buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); ++ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { ++ block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs); + } +- const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; +- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; +- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; +- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; +- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; + } + + barrier(); + +- pos_a_ib += 1; +- pos_b_ib += 1; ++ pos_a_ib += BK_STEP; ++ pos_b_ib += BK_STEP; + +-#ifdef COOPMAT +- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { +- const uint ib_a = warp_r * WM + cm_row * TM; ++ for (uint k_step = 0; k_step < BK_STEP; k_step++) { + // Load from shared into cache +- coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); +- +- // TODO: only cache values that are actually needed +- [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { +- cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; +- } +- +- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { +- const uint ib_b = warp_c * WN + cm_col * TN; +- coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); +- +- // TODO: only cache values that are actually needed +- [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { +- cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; +- } +- +- cm_result = coopmat(0); +- cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); +- +- [[unroll]] for (uint col = 0; col < TN; col += storestride) { +- coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = 
ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); +- } +- +- coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); +- sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); +- } +- } +-#else +- // Load from shared into cache +- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { +- [[unroll]] for (uint cr = 0; cr < TM; cr++) { +- const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; +- cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; +- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { +- cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; +- } +- } +- } ++ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { ++ [[unroll]] for (uint cr = 0; cr < TM; cr++) { ++ const uint reg_ib = wsir * TM + cr; ++ const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; + +- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { +- [[unroll]] for (uint cc = 0; cc < TN; cc++) { +- const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; +- cache_b_ds[cc] = buf_b_ds[ib]; +- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { +- cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; ++ block_a_to_registers(reg_ib, k_step * BM + buf_ib); + } + } + +- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { ++ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +- [[unroll]] for (uint cr = 0; cr < TM; cr++) { +- const uint cache_a_idx = wsir * TM + cr; +- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; +- int32_t q_sum = 0; +- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { +- q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], +- cache_b_qs[cc * (BK / 4) + idx_k]); +- } ++ const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc; ++ block_b_to_registers(ib); + +- sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); ++ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { ++ [[unroll]] for (uint cr = 0; cr < TM; cr++) { ++ const uint cache_a_idx = wsir * TM + cr; ++ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; ++ ++ sums[sums_idx] += mmq_dot_product(cache_a_idx); ++ } + } + } + } + } +-#endif + + barrier(); + } +@@ -373,54 +269,6 @@ void main() { + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + #endif + +-#ifdef COOPMAT +-#ifdef MUL_MAT_ID +- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { +- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { +- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); +- +- [[unroll]] for (uint col = 0; col < BN; col += storestride) { +- const uint row_i = dc + cm_col * TN + col + store_c; +- if (row_i >= _ne1) break; +- +- const u16vec2 row_idx = row_ids[row_i]; +- +- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); +- } +- } +- } +-#else +- const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float +- +- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { +- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { +- const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; +- +- if (is_aligned 
&& is_in_bounds) { +- // Full coopMat is within bounds and stride_d is aligned with 16B +- coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); +- coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); +- } else if (is_in_bounds) { +- // Full coopMat is within bounds, but stride_d is not aligned +- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); +- +- [[unroll]] for (uint col = 0; col < TN; col += storestride) { +- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); +- } +- } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { +- // Partial coopMat is within bounds +- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); +- +- [[unroll]] for (uint col = 0; col < TN; col += storestride) { +- if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { +- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); +- } +- } +- } +- } +- } +-#endif // MUL_MAT_ID +-#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + +@@ -431,19 +279,21 @@ void main() { + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + +- const u16vec2 row_idx = row_ids[row_i]; ++ const u16vec2 row_idx = row_ids[row_i - ic * BN]; + #endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { ++ const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr; + #ifdef MUL_MAT_ID +- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); ++ if (dr_warp + cr < p.M) { ++ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); ++ } + #else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { +- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); ++ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); + } + #endif // MUL_MAT_ID + } + } + } + } +-#endif // COOPMAT + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +index fe71eb131..c0c03fedc 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +@@ -6,41 +6,89 @@ + + // Each iqs value maps to a 32-bit integer + +-#if defined(DATA_A_Q4_0) ++#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) ++// 2-byte loads for Q4_0 blocks (18 bytes) ++// 4-byte loads for Q4_1 blocks (20 bytes) + i32vec2 repack(uint ib, uint iqs) { +- // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 +- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], +- data_a[ib].qs[iqs * 2 + 1]); ++#ifdef DATA_A_Q4_0 ++ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], ++ data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); ++#else // DATA_A_Q4_1 ++ const uint32_t vui = data_a_packed32[ib].qs[iqs]; ++ 
return i32vec2( vui & 0x0F0F0F0F, ++ (vui >> 4) & 0x0F0F0F0F); ++#endif + } + ++#ifdef DATA_A_Q4_0 + ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); + } ++#else // DATA_A_Q4_1 ++ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { ++ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); ++} + #endif + +-#if defined(DATA_A_Q4_1) +-i32vec2 repack(uint ib, uint iqs) { +- // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 +- const uint32_t vui = data_a_packed32[ib].qs[iqs]; +- return i32vec2( vui & 0x0F0F0F0F, +- (vui >> 4) & 0x0F0F0F0F); ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++#ifdef DATA_A_Q4_0 ++ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], ++ data_a_packed16[ib].qs[iqs * 2 + 1])); ++ ++ if (iqs == 0) { ++ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); ++ } ++#else // DATA_A_Q4_1 ++ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; ++ ++ if (iqs == 0) { ++ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); ++ } ++#endif + } + +-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { +- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].dm = buf_a[buf_ib].dm; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } + } +-#endif + +-#if defined(DATA_A_Q5_0) ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ int32_t q_sum = 0; ++ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { ++ const uint32_t vui = cache_a[ib_a].qs[iqs]; ++ const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F, ++ (vui >> 4) & 0x0F0F0F0F); ++ ++ const int32_t qs_b0 = cache_b.qs[iqs]; ++ const int32_t qs_b1 = cache_b.qs[iqs + 4]; ++ ++ q_sum += dotPacked4x8EXT(qs_a.x, qs_b0); ++ q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); ++ } ++ ++ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); ++} ++#endif // MMQ_SHMEM ++ ++#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) ++// 2-byte loads for Q5_0 blocks (22 bytes) ++// 4-byte loads for Q5_1 blocks (24 bytes) + i32vec2 repack(uint ib, uint iqs) { +- // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 +- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], +- data_a[ib].qs[iqs * 2 + 1]); ++ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], ++ data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); +- const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); ++#ifdef DATA_A_Q5_0 ++ const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); ++#else // DATA_A_Q5_1 ++ const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); ++#endif + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + +@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) { + return i32vec2(v0, v1); + } + ++#ifdef DATA_A_Q5_0 + ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); + } ++#else // DATA_A_Q5_1 ++ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, 
const vec2 dsb, const int32_t sum_divisor) { ++ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); ++} + #endif + +-#if defined(DATA_A_Q5_1) +-i32vec2 repack(uint ib, uint iqs) { +- // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 +- const uint32_t vui = data_a_packed32[ib].qs[iqs]; +- const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); +- const int32_t v0 = int32_t(vui & 0x0F0F0F0F) +- | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++#ifdef DATA_A_Q5_0 ++ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], ++ data_a_packed16[ib].qs[iqs * 2 + 1])); + +- const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) +- | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) ++ if (iqs == 0) { ++ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); ++ buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1])); ++ } ++#else // DATA_A_Q5_1 ++ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; + +- return i32vec2(v0, v1); ++ if (iqs == 0) { ++ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); ++ buf_a[buf_ib].qh = data_a_packed32[ib].qh; ++ } ++#endif + } + +-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { +- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].dm = buf_a[buf_ib].dm; ++ cache_a[reg_ib].qh = buf_a[buf_ib].qh; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } + } ++ ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ int32_t q_sum = 0; ++ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { ++ const uint32_t vui = cache_a[ib_a].qs[iqs]; ++ const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs)); ++ const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F) ++ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) ++ const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F) ++ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) ++ ++ const int32_t qs_b0 = cache_b.qs[iqs]; ++ const int32_t qs_b1 = cache_b.qs[iqs + 4]; ++ ++ q_sum += dotPacked4x8EXT(qs_a0, qs_b0); ++ q_sum += dotPacked4x8EXT(qs_a1, qs_b1); ++ } ++ ++ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); ++} ++#endif // MMQ_SHMEM + #endif + + #if defined(DATA_A_Q8_0) ++// 2-byte loads for Q8_0 blocks (34 bytes) + int32_t repack(uint ib, uint iqs) { +- // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 +- return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], +- data_a[ib].qs[iqs * 2 + 1])); ++ return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], ++ data_a_packed16[ib].qs[iqs * 2 + 1])); + } + + ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * da * dsb.x); + } ++ ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++ buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], ++ data_a_packed16[ib].qs[iqs * 2 + 1])); ++ ++ if (iqs == 0) { ++ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); ++ } ++} ++ ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].dm = buf_a[buf_ib].dm; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) 
{ ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } ++} ++ ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ int32_t q_sum = 0; ++ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { ++ const int32_t qs_a = cache_a[ib_a].qs[iqs]; ++ const int32_t qs_b = cache_b.qs[iqs]; ++ ++ q_sum += dotPacked4x8EXT(qs_a, qs_b); ++ } ++ ++ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); ++} ++#endif // MMQ_SHMEM ++#endif ++ ++#if defined(DATA_A_MXFP4) ++// 1-byte loads for mxfp4 blocks (17 bytes) ++i32vec2 repack(uint ib, uint iqs) { ++ const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], ++ data_a[ib].qs[iqs * 4 + 1], ++ data_a[ib].qs[iqs * 4 + 2], ++ data_a[ib].qs[iqs * 4 + 3])); ++ ++ return i32vec2( quants & 0x0F0F0F0F, ++ (quants >> 4) & 0x0F0F0F0F); ++} ++ ++ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { ++ return ACC_TYPE(da * dsb.x * float(q_sum)); ++} ++ ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++ const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], ++ data_a[ib].qs[iqs * 4 + 1], ++ data_a[ib].qs[iqs * 4 + 2], ++ data_a[ib].qs[iqs * 4 + 3])); ++ ++ const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); ++ const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); ++ ++ buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])); ++ buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])); ++ ++ if (iqs == 0) { ++ buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5); ++ } ++} ++ ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].d = buf_a[buf_ib].d; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } ++} ++ ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ int32_t q_sum = 0; ++ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { ++ const int32_t qs_a = cache_a[ib_a].qs[iqs]; ++ ++ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); ++ } ++ ++ return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); ++} ++#endif // MMQ_SHMEM ++#endif ++ ++// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide ++// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants ++#if defined(DATA_A_Q2_K) ++// 4-byte loads for Q2_K blocks (84 bytes) ++int32_t repack(uint ib, uint iqs) { ++ const uint ib_k = ib / 8; ++ const uint iqs_k = (ib % 8) * 8 + iqs; ++ ++ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); ++ const uint qs_shift = ((iqs_k % 32) / 8) * 2; ++ ++ return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); ++} ++ ++uint8_t get_scale(uint ib, uint iqs) { ++ const uint ib_k = ib / 8; ++ const uint iqs_k = (ib % 8) * 8 + iqs; ++ ++ return data_a[ib_k].scales[iqs_k / 4]; ++} ++ ++ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { ++ return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); ++} ++ ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++ const uint ib_k = ib / 8; ++ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; ++ ++ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); ++ const uint qs_shift = ((iqs_k % 32) / 8) * 2; ++ ++ // Repack 4x4 quants into one int ++ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303; ++ 
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303; ++ const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303; ++ const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303; ++ ++ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); ++ ++ if (iqs == 0) { ++ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); ++ buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]); ++ } ++} ++ ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].dm = buf_a[buf_ib].dm; ++ cache_a[reg_ib].scales = buf_a[buf_ib].scales; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) { ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } ++} ++ ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ int32_t sum_d = 0; ++ int32_t sum_m = 0; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { ++ const uint8_t scale = cache_a[ib_a].scales[iqs / 4]; ++ const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits. ++ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303); ++ ++ sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF); ++ sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); ++ } ++ ++ return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); ++} ++#endif // MMQ_SHMEM ++#endif ++ ++#if defined(DATA_A_Q3_K) ++// 2-byte loads for Q3_K blocks (110 bytes) ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++ const uint ib_k = ib / 8; ++ const uint hm_idx = iqs * QUANT_R_MMQ; ++ const uint iqs_k = (ib % 8) * 8 + hm_idx; ++ ++ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); ++ const uint qs_shift = ((iqs_k % 32) / 8) * 2; ++ const uint hm_shift = iqs_k / 8; ++ ++ // Repack 2x4 quants into one int ++ // Add the 3rd bit instead of subtracting it to allow packing the quants ++ const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | ++ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); ++ const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) | ++ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); ++ const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) | ++ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); ++ const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) | ++ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); ++ buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) | ++ (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4); ++ ++ if (iqs == 0) { ++ const uint is = iqs_k / 4; ++ const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | ++ (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))); ++ ++ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32); ++ } ++} ++ ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; ++ ++ [[unroll]] for (uint 
iqs = 0; iqs < 4; iqs++) { ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } ++} ++ ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ float result = 0.0; ++ int32_t q_sum = 0; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { ++ // Subtract 4 from the quants to correct the 3rd bit offset ++ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); ++ ++ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); ++ } ++ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); ++ q_sum = 0; ++ ++ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { ++ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); ++ ++ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); ++ } ++ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); ++ ++ return ACC_TYPE(cache_b.ds.x * result); ++} ++#endif // MMQ_SHMEM ++#endif ++ ++#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) ++// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) ++ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { ++ return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); ++} ++ ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++ const uint ib_k = ib / 8; ++ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; ++ ++ const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); ++ const uint qs_shift = ((iqs_k % 16) / 8) * 4; ++ ++ // Repack 2x4 quants into one int ++#if defined(DATA_A_Q4_K) ++ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; ++ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; ++ ++ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4); ++#else // defined(DATA_A_Q5_K) ++ const uint qh_idx = iqs * QUANT_R_MMQ; ++ const uint qh_shift = iqs_k / 8; ++ ++ buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) | ++ (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); ++#endif ++ ++ ++ if (iqs == 0) { ++ // Scale index ++ const uint is = iqs_k / 8; ++ u8vec2 scale_dm; ++ if (is < 4) { ++ scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); ++ } else { ++ scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), ++ (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); ++ } ++ ++ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); ++ } ++} ++ ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].dm = buf_a[buf_ib].dm; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) { ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } ++} ++ ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ int32_t q_sum = 0; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { ++#if defined(DATA_A_Q4_K) ++ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F); ++#else // defined(DATA_A_Q5_K) ++ const int32_t qs_a = cache_a[ib_a].qs[iqs]; ++#endif ++ ++ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); ++ } ++ ++ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); ++} ++#endif // MMQ_SHMEM ++#endif ++ ++#ifdef MMQ_SHMEM ++void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++ const uint ib_outer = ib / 4; ++ const uint ib_inner = ib % 4; ++ ++ if (iqs == 0) { ++ 
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); ++ } ++ ++ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; ++ buf_b[buf_ib].qs[iqs * 4 ] = values.x; ++ buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; ++ buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; ++ buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; ++} ++ ++void block_b_to_registers(const uint ib) { ++ cache_b.ds = buf_b[ib].ds; ++ [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { ++ cache_b.qs[iqs] = buf_b[ib].qs[iqs]; ++ } ++} ++#endif ++ ++#if defined(DATA_A_Q6_K) ++// 2-byte loads for Q6_K blocks (210 bytes) ++#ifdef MMQ_SHMEM ++void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { ++ const uint ib_k = ib / 8; ++ const uint iqs_k = (ib % 8) * 8 + iqs; ++ ++ const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; ++ const uint ql_shift = ((iqs_k % 32) / 16) * 4; ++ ++ const uint qh_idx = (iqs_k / 32) * 8 + iqs; ++ const uint qh_shift = ((iqs_k % 32) / 8) * 2; ++ ++ const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | ++ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); ++ const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | ++ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); ++ buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)); ++ ++ if (iqs == 0) { ++ const uint is = iqs_k / 4; ++ const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]); ++ ++ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales); ++ } ++} ++ ++void block_a_to_registers(const uint reg_ib, const uint buf_ib) { ++ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { ++ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; ++ } ++} ++ ++ACC_TYPE mmq_dot_product(const uint ib_a) { ++ float result = 0.0; ++ int32_t q_sum = 0; ++ ++ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { ++ const int32_t qs_a = cache_a[ib_a].qs[iqs]; ++ ++ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); ++ } ++ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); ++ q_sum = 0; ++ ++ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { ++ const int32_t qs_a = cache_a[ib_a].qs[iqs]; ++ ++ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); ++ } ++ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); ++ ++ return ACC_TYPE(cache_b.ds.x * result); ++} ++#endif // MMQ_SHMEM + #endif + + #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + } + #endif ++ ++#if defined(DATA_A_Q2_K) ++FLOAT_TYPE_VEC2 get_dm(uint ib) { ++ const uint ib_k = ib / 8; ++ return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); ++} ++#endif +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +new file mode 100644 +index 000000000..72fec4404 +--- /dev/null ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +@@ -0,0 +1,78 @@ ++#if defined(DATA_A_Q4_0) ++#define QUANT_R_MMQ 2 ++struct block_a_cache { ++ 
uint32_t qs[16/4]; ++ FLOAT_TYPE dm; ++}; ++#elif defined(DATA_A_Q4_1) ++#define QUANT_R_MMQ 2 ++struct block_a_cache { ++ uint32_t qs[16/4]; ++ FLOAT_TYPE_VEC2 dm; ++}; ++#elif defined(DATA_A_Q5_0) ++#define QUANT_R_MMQ 2 ++struct block_a_cache { ++ uint32_t qs[16/4]; ++ uint32_t qh; ++ FLOAT_TYPE dm; ++}; ++#elif defined(DATA_A_Q5_1) ++#define QUANT_R_MMQ 2 ++struct block_a_cache { ++ uint32_t qs[16/4]; ++ uint32_t qh; ++ FLOAT_TYPE_VEC2 dm; ++}; ++#elif defined(DATA_A_Q8_0) ++#define QUANT_R_MMQ 1 ++// AMD likes 4, Intel likes 1 and Nvidia likes 2 ++#define BK_STEP 1 ++struct block_a_cache { ++ int32_t qs[32/4]; ++ FLOAT_TYPE dm; ++}; ++#elif defined(DATA_A_MXFP4) ++#define QUANT_R_MMQ 2 ++struct block_a_cache { ++ int32_t qs[8]; ++ FLOAT_TYPE d; ++}; ++#elif defined(DATA_A_Q2_K) ++#define QUANT_R_MMQ 4 ++struct block_a_cache { ++ uint32_t qs[2]; ++ u8vec2 scales; ++ FLOAT_TYPE_VEC2 dm; ++}; ++#elif defined(DATA_A_Q3_K) ++#define QUANT_R_MMQ 2 ++struct block_a_cache { ++ uint32_t qs[4]; ++ FLOAT_TYPE_VEC2 d_scales; ++}; ++#elif defined(DATA_A_Q4_K) ++#define QUANT_R_MMQ 2 ++struct block_a_cache { ++ uint32_t qs[4]; ++ FLOAT_TYPE_VEC2 dm; ++}; ++#elif defined(DATA_A_Q5_K) ++#define QUANT_R_MMQ 1 ++struct block_a_cache { ++ int32_t qs[8]; ++ FLOAT_TYPE_VEC2 dm; ++}; ++#elif defined(DATA_A_Q6_K) ++#define QUANT_R_MMQ 1 ++struct block_a_cache { ++ int32_t qs[8]; ++ FLOAT_TYPE_VEC2 d_scales; ++}; ++#endif ++ ++struct block_b_cache ++{ ++ int32_t qs[8]; ++ FLOAT_TYPE_VEC2 ds; ++}; +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +index 2fa54ce51..02578c77c 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +@@ -66,6 +66,7 @@ struct block_q4_0_packed16 + #define QUANT_AUXF 1 + #define A_TYPE block_q4_0 + #define A_TYPE_PACKED16 block_q4_0_packed16 ++#define DATA_A_QUANT_LEGACY + #endif + + #define QUANT_K_Q4_1 32 +@@ -98,6 +99,7 @@ struct block_q4_1_packed32 + #define A_TYPE block_q4_1 + #define A_TYPE_PACKED16 block_q4_1_packed16 + #define A_TYPE_PACKED32 block_q4_1_packed32 ++#define DATA_A_QUANT_LEGACY + #endif + + #define QUANT_K_Q5_0 32 +@@ -123,6 +125,7 @@ struct block_q5_0_packed16 + #define QUANT_AUXF 1 + #define A_TYPE block_q5_0 + #define A_TYPE_PACKED16 block_q5_0_packed16 ++#define DATA_A_QUANT_LEGACY + #endif + + #define QUANT_K_Q5_1 32 +@@ -158,6 +161,7 @@ struct block_q5_1_packed32 + #define A_TYPE block_q5_1 + #define A_TYPE_PACKED16 block_q5_1_packed16 + #define A_TYPE_PACKED32 block_q5_1_packed32 ++#define DATA_A_QUANT_LEGACY + #endif + + #define QUANT_K_Q8_0 32 +@@ -186,6 +190,7 @@ struct block_q8_0_packed32 + #define A_TYPE block_q8_0 + #define A_TYPE_PACKED16 block_q8_0_packed16 + #define A_TYPE_PACKED32 block_q8_0_packed32 ++#define DATA_A_QUANT_LEGACY + #endif + + #define QUANT_K_Q8_1 32 +@@ -226,21 +231,21 @@ struct block_q2_K + { + uint8_t scales[QUANT_K_Q2_K/16]; + uint8_t qs[QUANT_K_Q2_K/4]; +- f16vec2 d; ++ f16vec2 dm; + }; + + struct block_q2_K_packed16 + { + uint16_t scales[QUANT_K_Q2_K/16/2]; + uint16_t qs[QUANT_K_Q2_K/4/2]; +- f16vec2 d; ++ f16vec2 dm; + }; + + struct block_q2_K_packed32 + { + uint32_t scales[QUANT_K_Q2_K/16/4]; + uint32_t qs[QUANT_K_Q2_K/4/4]; +- f16vec2 d; ++ f16vec2 dm; + }; + + #if defined(DATA_A_Q2_K) +@@ -249,6 +254,8 @@ struct block_q2_K_packed32 + #define A_TYPE block_q2_K + #define A_TYPE_PACKED16 block_q2_K_packed16 + #define A_TYPE_PACKED32 block_q2_K_packed32 ++#define SCALES_PER_32 2 ++#define DATA_A_QUANT_K + 
#endif + + #define QUANT_K_Q3_K 256 +@@ -274,27 +281,28 @@ struct block_q3_K_packed16 + #define QUANT_R 1 + #define A_TYPE block_q3_K + #define A_TYPE_PACKED16 block_q3_K_packed16 ++#define DATA_A_QUANT_K + #endif + + #define QUANT_K_Q4_K 256 + + struct block_q4_K + { +- f16vec2 d; ++ f16vec2 dm; + uint8_t scales[3*QUANT_K_Q4_K/64]; + uint8_t qs[QUANT_K_Q4_K/2]; + }; + + struct block_q4_K_packed16 + { +- f16vec2 d; ++ f16vec2 dm; + uint16_t scales[3*QUANT_K_Q4_K/64/2]; + uint16_t qs[QUANT_K_Q4_K/2/2]; + }; + + struct block_q4_K_packed32 + { +- f16vec2 d; ++ f16vec2 dm; + uint32_t scales[3*QUANT_K_Q4_K/64/4]; + uint32_t qs[QUANT_K_Q4_K/2/4]; + }; +@@ -310,13 +318,14 @@ struct block_q4_K_packed128 + #define A_TYPE block_q4_K + #define A_TYPE_PACKED16 block_q4_K_packed16 + #define A_TYPE_PACKED32 block_q4_K_packed32 ++#define DATA_A_QUANT_K + #endif + + #define QUANT_K_Q5_K 256 + + struct block_q5_K + { +- f16vec2 d; ++ f16vec2 dm; + uint8_t scales[12]; + uint8_t qh[QUANT_K_Q5_K/8]; + uint8_t qs[QUANT_K_Q5_K/2]; +@@ -324,12 +333,20 @@ struct block_q5_K + + struct block_q5_K_packed16 + { +- f16vec2 d; ++ f16vec2 dm; + uint16_t scales[12/2]; + uint16_t qh[QUANT_K_Q5_K/8/2]; + uint16_t qs[QUANT_K_Q5_K/2/2]; + }; + ++struct block_q5_K_packed32 ++{ ++ f16vec2 dm; ++ uint32_t scales[12/4]; ++ uint32_t qh[QUANT_K_Q5_K/8/4]; ++ uint32_t qs[QUANT_K_Q5_K/2/4]; ++}; ++ + struct block_q5_K_packed128 + { + uvec4 q5k[11]; +@@ -340,6 +357,8 @@ struct block_q5_K_packed128 + #define QUANT_R 1 + #define A_TYPE block_q5_K + #define A_TYPE_PACKED16 block_q5_K_packed16 ++#define A_TYPE_PACKED32 block_q5_K_packed32 ++#define DATA_A_QUANT_K + #endif + + #define QUANT_K_Q6_K 256 +@@ -356,7 +375,7 @@ struct block_q6_K_packed16 + { + uint16_t ql[QUANT_K_Q6_K/2/2]; + uint16_t qh[QUANT_K_Q6_K/4/2]; +- int8_t scales[QUANT_K_Q6_K/16]; ++ int16_t scales[QUANT_K_Q6_K/16/2]; + float16_t d; + }; + +@@ -365,6 +384,7 @@ struct block_q6_K_packed16 + #define QUANT_R 1 + #define A_TYPE block_q6_K + #define A_TYPE_PACKED16 block_q6_K_packed16 ++#define DATA_A_QUANT_K + #endif + + // IQuants +@@ -1363,18 +1383,11 @@ struct block_mxfp4 + uint8_t qs[QUANT_K_MXFP4/2]; + }; + +-//struct block_mxfp4_packed16 +-//{ +-// uint8_t e; +-// uint16_t qs[QUANT_K_MXFP4/2/2]; +-//}; +- + #if defined(DATA_A_MXFP4) + #define QUANT_K QUANT_K_MXFP4 + #define QUANT_R QUANT_R_MXFP4 + #define QUANT_AUXF 1 + #define A_TYPE block_mxfp4 +-//#define A_TYPE_PACKED16 block_mxfp4_packed16 + #endif + + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) +@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize) + #endif + + #if defined(DATA_A_MXFP4) +-const FLOAT_TYPE kvalues_mxfp4_const[16] = { +- FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), +- FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) ++const int8_t kvalues_mxfp4_const[16] = { ++ int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), ++ int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), + }; + +-shared FLOAT_TYPE kvalues_mxfp4[16]; ++shared int8_t kvalues_mxfp4[16]; + + #define NEEDS_INIT_IQ_SHMEM + void init_iq_shmem(uvec3 wgsize) +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +index 0f25ba345..03fa01639 100644 +--- 
a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +@@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c + } + + #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +- if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { ++ // Integer dot mmq performs better with f32 accumulators ++ if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + } + #endif +@@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c + } + + void process_shaders() { +- std::map base_dict = {{"FLOAT_TYPE", "float"}}; ++ std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; + + // matmul + for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { diff --git a/llama/patches/0031-vulkan-Update-topk_moe-fusion-to-handle-gpt-s-late-s.patch b/llama/patches/0031-vulkan-Update-topk_moe-fusion-to-handle-gpt-s-late-s.patch new file mode 100644 index 000000000..41cd9cd55 --- /dev/null +++ b/llama/patches/0031-vulkan-Update-topk_moe-fusion-to-handle-gpt-s-late-s.patch @@ -0,0 +1,657 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jeff Bolz +Date: Wed, 29 Oct 2025 08:44:29 -0500 +Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax + (#16656) + +* vulkan: Update topk_moe fusion to handle gpt's late softmax + +Based on #16649. + +* Add ggml_check_edges + +* Add sync logging to show fusion effects + +* handle clamp added in #16655 + +* Update ggml/src/ggml-impl.h + +Co-authored-by: Diego Devesa +--- + ggml/src/ggml-impl.h | 16 + + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 304 +++++++++++------- + .../ggml-vulkan/vulkan-shaders/topk_moe.comp | 90 ++++-- + 3 files changed, 272 insertions(+), 138 deletions(-) + +diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h +index 639d551a2..e5c446d1d 100644 +--- a/ggml/src/ggml-impl.h ++++ b/ggml/src/ggml-impl.h +@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release(); + #endif + + #ifdef __cplusplus ++#include + #include + #include + +@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, + return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size()); + } + ++// Return true if the edges in the graph match expectations. 
++inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, ++ int start_idx, ++ std::initializer_list> edges) { ++ for (const auto & edge : edges) { ++ int dst_node = edge[0]; ++ int src_idx = edge[1]; ++ int src_node = edge[2]; ++ if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) { ++ return false; ++ } ++ } ++ return true; ++} ++ + // expose GGUF internals for test code + GGML_API size_t gguf_type_size(enum gguf_type type); + GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 53b57c179..b2855b078 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11; + static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); + static constexpr uint32_t num_topk_moe_pipelines = 10; + +-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, +- GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, +- GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; +-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, +- GGML_OP_VIEW, GGML_OP_GET_ROWS }; ++static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, ++ GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, ++ GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, ++ GGML_OP_RESHAPE }; ++static constexpr std::initializer_list topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, ++ GGML_OP_VIEW, GGML_OP_GET_ROWS }; ++static constexpr std::initializer_list topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, ++ GGML_OP_GET_ROWS, GGML_OP_RESHAPE, ++ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; ++ ++//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] ++//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] ++//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] ++//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ] ++//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ] ++//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ] ++//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ++//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ] ++//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ] ++//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ] ++static constexpr std::initializer_list> topk_moe_early_softmax_norm_edges { ++ { 1, 0, 0 }, // reshape->src[0] == softmax ++ { 2, 0, 0 }, // argsort->src[0] == softmax ++ { 3, 0, 2 }, // view->src[0] == argsort ++ { 4, 0, 1 }, // get_rows->src[0] == reshape ++ { 4, 1, 3 }, // get_rows->src[1] == view ++ { 5, 0, 4 }, // reshape->src[0] == get_rows ++ { 6, 0, 5 }, // sum_rows->src[0] == reshape ++ { 7, 0, 6 }, // clamp->src[0] == sum_rows ++ { 8, 0, 5 }, // div->src[0] == reshape ++ { 8, 1, 7 }, // div->src[1] == clamp ++ { 9, 0, 8 }, // 
reshape->src[0] == div ++}; ++ ++// same as early_softmax_norm but ending after the get_rows ++static constexpr std::initializer_list> topk_moe_early_softmax_edges { ++ { 1, 0, 0 }, // reshape->src[0] == softmax ++ { 2, 0, 0 }, // argsort->src[0] == softmax ++ { 3, 0, 2 }, // view->src[0] == argsort ++ { 4, 0, 1 }, // get_rows->src[0] == reshape ++ { 4, 1, 3 }, // get_rows->src[1] == view ++}; + ++//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ] ++//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ] ++//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ] ++//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ] ++//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ] ++//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ] ++static constexpr std::initializer_list> topk_moe_late_softmax_edges { ++ { 1, 0, 0 }, // view->src[0] == argsort ++ { 2, 1, 1 }, // get_rows->src[1] == view ++ { 3, 0, 2 }, // reshape->src[0] == get_rows ++ { 4, 0, 3 }, // soft_max->src[0] == reshape ++ { 5, 0, 4 }, // reshape->src[0] == soft_max ++}; ++ ++enum topk_moe_mode { ++ TOPK_MOE_EARLY_SOFTMAX, ++ TOPK_MOE_EARLY_SOFTMAX_NORM, ++ TOPK_MOE_LATE_SOFTMAX, ++ TOPK_MOE_COUNT, ++}; ++ ++static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { ++ topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : ++ num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX : ++ TOPK_MOE_LATE_SOFTMAX; ++ return mode; ++} + + struct vk_device_struct { + std::recursive_mutex mutex; +@@ -607,8 +671,7 @@ struct vk_device_struct { + + vk_pipeline pipeline_flash_attn_split_k_reduce; + +- // [2] is {!norm, norm} +- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; ++ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT]; + + std::vector all_pipelines; + +@@ -956,6 +1019,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); + struct vk_op_topk_moe_push_constants { + uint32_t n_rows; + uint32_t n_expert_used; ++ float clamp_min; ++ float clamp_max; + }; + + struct vk_op_add_id_push_constants { +@@ -3806,8 +3871,9 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + + for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { +- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, 
sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<num_additional_fused_ops) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + GGML_ASSERT(idx < num_topk_moe_pipelines); +- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; +- return ctx->device->pipeline_topk_moe[idx][with_norm]; ++ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); ++ return ctx->device->pipeline_topk_moe[idx][mode]; + } + + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { +@@ -8141,6 +8207,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const + return nullptr; + } + case GGML_OP_ARGSORT: ++ if (ctx->num_additional_fused_ops) { ++ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); ++ GGML_ASSERT(idx < num_topk_moe_pipelines); ++ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); ++ return ctx->device->pipeline_topk_moe[idx][mode]; ++ } ++ + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + return ctx->device->pipeline_argsort_f32[idx]; +@@ -9676,10 +9749,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub + + static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { + +- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; ++ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; +- ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; +- ggml_tensor * ids = cgraph->nodes[node_idx + 3]; ++ ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : ++ (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : ++ cgraph->nodes[node_idx + 5]; ++ ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? 
cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; + + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); +@@ -9738,9 +9813,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, + GGML_ASSERT(d_ids != nullptr); + } + +- vk_op_topk_moe_push_constants pc; ++ vk_op_topk_moe_push_constants pc {}; + pc.n_rows = n_rows; + pc.n_expert_used = n_expert_used; ++ if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { ++ ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; ++ pc.clamp_min = ggml_get_op_params_f32(clamp, 0); ++ pc.clamp_max = ggml_get_op_params_f32(clamp, 1); ++ } + + GGML_ASSERT(n_expert_used <= n_experts); + +@@ -11335,7 +11415,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr + } + } + } ++ ++#define ENABLE_SYNC_LOGGING 0 ++ + if (need_sync) { ++#if ENABLE_SYNC_LOGGING ++ std::cerr << "sync" << std::endl; ++#endif + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ggml_vk_sync_buffers(ctx, compute_ctx); +@@ -11353,6 +11439,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr + } + } + } ++#if ENABLE_SYNC_LOGGING ++ if (!dryrun) { ++ for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { ++ auto *n = cgraph->nodes[node_idx + i]; ++ std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; ++ if (n->op == GGML_OP_GLU) { ++ std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; ++ } ++ std::cerr << std::endl; ++ } ++ } ++#endif + + switch (node->op) { + case GGML_OP_REPEAT: +@@ -11531,7 +11629,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr + + break; + case GGML_OP_ARGSORT: +- ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); ++ if (ctx->num_additional_fused_ops) { ++ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun); ++ } else { ++ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); ++ } + + break; + case GGML_OP_SUM: +@@ -12329,30 +12431,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st + } + + static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, +- int node_idx, bool with_norm) { ++ int node_idx, topk_moe_mode mode) { + +- if (with_norm) { +- if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) { +- return false; +- } +- for (size_t i = 0; i < topk_moe_norm.size(); ++i) { +- if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) { +- return false; +- } +- } +- } else { +- if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) { +- return false; +- } +- for (size_t i = 0; i < topk_moe.size(); ++i) { +- if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) { +- return false; +- } +- } +- } ++ const ggml_tensor * softmax; ++ const ggml_tensor * weights; + +- const ggml_tensor * softmax = cgraph->nodes[node_idx + 0]; +- const ggml_tensor * weights = with_norm ? 
cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; ++ switch (mode) { ++ case TOPK_MOE_EARLY_SOFTMAX_NORM: ++ softmax = cgraph->nodes[node_idx + 0]; ++ weights = cgraph->nodes[node_idx + 9]; ++ break; ++ case TOPK_MOE_EARLY_SOFTMAX: ++ softmax = cgraph->nodes[node_idx + 0]; ++ weights = cgraph->nodes[node_idx + 4]; ++ break; ++ case TOPK_MOE_LATE_SOFTMAX: ++ softmax = cgraph->nodes[node_idx + 4]; ++ weights = cgraph->nodes[node_idx + 5]; ++ break; ++ default: ++ return false; ++ } + + const float * op_params = (const float *)softmax->op_params; + +@@ -12378,60 +12477,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc + return false; + } + +- // Check that the nodes don't have any unexpected uses +- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1]; +- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; +- const ggml_tensor * view = cgraph->nodes[node_idx + 3]; +- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; +- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr; +- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr; +- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr; +- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr; +- +- // softmax is used by reshape and argsort +- if (ggml_node_get_use_count(cgraph, node_idx) != 2 || +- reshape1->src[0] != softmax || +- argsort->src[0] != softmax) { +- return false; +- } +- // reshape is used by get_rows +- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 || +- get_rows->src[0] != reshape1) { +- return false; +- } +- // argsort is used by view +- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 || +- view->src[0] != argsort) { +- return false; +- } +- // view is written (via argsort), we can skip checking it +- +- if (with_norm) { +- // get_rows is used by reshape +- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 || +- reshape5->src[0] != get_rows) { +- return false; +- } +- +- // reshape is used by sum_rows and div +- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 || +- sum_rows->src[0] != reshape5 || +- div->src[0] != reshape5) { +- return false; +- } +- +- // sum_rows is used by div +- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 || +- div->src[1] != sum_rows) { +- return false; +- } +- +- // div/reshape are written +- if (reshape8->src[0] != div) { +- return false; +- } +- } +- + if (!ctx->device->subgroup_arithmetic || + !ctx->device->subgroup_shuffle || + !ctx->device->subgroup_require_full_support || +@@ -12517,10 +12562,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; +- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { +- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; +- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { +- ctx->num_additional_fused_ops = topk_moe.size() - 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ++ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ++ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { ++ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ++ 
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ++ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { ++ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ++ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ++ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { ++ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; + } + } + ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); +@@ -12618,10 +12671,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; +- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { +- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; +- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { +- ctx->num_additional_fused_ops = topk_moe.size() - 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ++ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ++ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { ++ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ++ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ++ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { ++ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ++ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ++ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { ++ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; + } + } + +@@ -12754,25 +12815,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * + while (first_unused < graph->n_nodes) { + std::vector current_set; + +- // Avoid reordering topk_moe_norm +- if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) { +- bool is_topk_moe_norm = true; +- for (size_t j = 0; j < topk_moe_norm.size(); ++j) { +- if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) { +- is_topk_moe_norm = false; ++ // Check for fusion patterns and avoid reordering them ++ auto const &match_pattern = [&](const std::initializer_list &pattern, int start) -> bool { ++ if (start + (int)pattern.size() <= graph->n_nodes) { ++ bool is_pattern = true; ++ for (size_t j = 0; j < pattern.size(); ++j) { ++ if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) { ++ is_pattern = false; ++ } + } ++ return is_pattern; + } +- if (is_topk_moe_norm) { +- for (size_t j = 0; j < topk_moe_norm.size(); ++j) { ++ return false; ++ }; ++ ++ auto const &keep_pattern = [&](const std::initializer_list &pattern) -> bool { ++ if (match_pattern(pattern, first_unused)) { ++ for (size_t j = 0; j < pattern.size(); ++j) { + new_order.push_back(graph->nodes[first_unused + j]); + used[first_unused + j] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } +- continue; ++ return true; + } ++ return false; ++ }; ++ ++ if (keep_pattern(topk_moe_early_softmax_norm)) { ++ continue; ++ } ++ if 
(keep_pattern(topk_moe_early_softmax)) { ++ continue; + } ++ if (keep_pattern(topk_moe_late_softmax)) { ++ continue; ++ } ++ + // First, grab the next unused node. + current_set.push_back(first_unused); + +@@ -12790,6 +12870,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * + if (is_empty(graph->nodes[j])) { + continue; + } ++ // Don't pull forward nodes from fusion patterns ++ if (match_pattern(topk_moe_early_softmax_norm, j) || ++ match_pattern(topk_moe_early_softmax, j) || ++ match_pattern(topk_moe_late_softmax, j)) { ++ continue; ++ } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +index 9e56d5f8a..bc1c278bf 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter + { + uint n_rows; + uint n_expert_used; ++ float clamp_min; ++ float clamp_max; + }; + + layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; +@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + layout(constant_id = 0) const uint WARP_SIZE = 32; + layout(constant_id = 1) const uint n_experts = 512; + layout(constant_id = 2) const bool with_norm = true; ++layout(constant_id = 3) const bool late_softmax = false; + + const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; + +@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];}; + layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; + layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; + +-void main() { +- const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; +- if (row >= n_rows) { +- return; +- } ++const float INFINITY = 1.0 / 0.0; + +- const uint logits_offset = n_experts * row; +- const uint weights_offset = n_expert_used * row; +- const uint ids_offset = n_experts * row; +- +- float logits_r[experts_per_thread]; +- +- const float INFINITY = 1.0 / 0.0; ++// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. ++void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) { ++ float max_val = -INFINITY; + + [[unroll]] +- for (uint i = 0; i < n_experts; i += WARP_SIZE) { +- const uint expert = i + gl_LocalInvocationID.x; +- logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? 
logits[logits_offset + expert] : -INFINITY; ++ for (int i = 0; i < experts_per_thread; i++) { ++ const uint idx = lane + i * WARP_SIZE; ++ const bool is_active = !use_limit || (idx < limit); ++ if (is_active) { ++ max_val = max(max_val, vals[i]); ++ } + } + +- float max_val = logits_r[0]; ++ max_val = subgroupMax(max_val); ++ ++ float sum = 0.f; + + [[unroll]] +- for (int i = 1; i < experts_per_thread; i++) { +- const float val = logits_r[i]; +- max_val = max(val, max_val); ++ for (int i = 0; i < experts_per_thread; i++) { ++ const uint idx = lane + i * WARP_SIZE; ++ const bool is_active = !use_limit || (idx < limit); ++ if (is_active) { ++ const float val = exp(vals[i] - max_val); ++ vals[i] = val; ++ sum += val; ++ } else { ++ vals[i] = 0.f; ++ } + } + +- max_val = subgroupMax(max_val); ++ sum = subgroupAdd(sum); + +- float wt[experts_per_thread]; +- float tmp = 0.f; ++ const float inv_sum = 1.0f / sum; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { +- const float val = logits_r[i]; +- wt[i] = exp(val - max_val); +- tmp += wt[i]; ++ const uint idx = lane + i * WARP_SIZE; ++ const bool is_active = !use_limit || (idx < limit); ++ if (is_active) { ++ vals[i] *= inv_sum; ++ } + } ++} + +- tmp = subgroupAdd(tmp); ++void main() { ++ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; ++ if (row >= n_rows) { ++ return; ++ } + +- const float inv_sum = 1.0f / tmp; ++ const uint logits_offset = n_experts * row; ++ const uint weights_offset = n_expert_used * row; ++ const uint ids_offset = n_experts * row; ++ ++ float wt[experts_per_thread]; + + [[unroll]] +- for (int i = 0; i < experts_per_thread; i++) { +- wt[i] = wt[i] * inv_sum; ++ for (uint i = 0; i < n_experts; i += WARP_SIZE) { ++ const uint expert = i + gl_LocalInvocationID.x; ++ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; ++ } ++ ++ if (!late_softmax) { ++ softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false); + } + + // at this point, each thread holds a portion of softmax, +@@ -82,6 +104,11 @@ void main() { + + float output_weights[experts_per_thread]; + ++ [[unroll]] ++ for (int i = 0; i < experts_per_thread; i++) { ++ output_weights[i] = 0.f; ++ } ++ + for (int k = 0; k < n_expert_used; k++) { + float max_val = wt[0]; + uint max_expert = gl_LocalInvocationID.x; +@@ -121,6 +148,7 @@ void main() { + + if (with_norm) { + wt_sum = subgroupAdd(wt_sum); ++ wt_sum = clamp(wt_sum, clamp_min, clamp_max); + const float inv_sum = 1.0f / wt_sum; + + [[unroll]] +@@ -129,6 +157,10 @@ void main() { + } + } + ++ if (late_softmax) { ++ softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true); ++ } ++ + [[unroll]] + for (uint i = 0; i < experts_per_thread; ++i) { + uint idx = i * WARP_SIZE + gl_LocalInvocationID.x; diff --git a/llama/patches/0032-vulkan-Fuse-rope-set_rows-16769.patch b/llama/patches/0032-vulkan-Fuse-rope-set_rows-16769.patch new file mode 100644 index 000000000..64c7ffa42 --- /dev/null +++ b/llama/patches/0032-vulkan-Fuse-rope-set_rows-16769.patch @@ -0,0 +1,1242 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jeff Bolz +Date: Wed, 29 Oct 2025 15:13:10 -0500 +Subject: [PATCH] vulkan: Fuse rope+set_rows (#16769) + +This pattern appears in a lot of models, the rope operation is applied right +before storing into the KV cache (usually on the K tensor). 
+ +Add a path to some of the rope shaders that computes the destination address +based on the set_rows tensor. Compile variants of the shader with D_TYPE of +f16 (the usual KV cache type). + +Add a src3 operand to ggml_vk_op_f32 - sometimes rope uses three srcs and needs +the fourth for the row indices. + +Add fused_ops_write_mask to indicate which intermediate tensors need to write +their results to memory. Skipping writing the roped K value helps to allow more +nodes to run concurrently. + +Add logic to ggml_vk_graph_optimize to make ROPE+VIEW+SET_ROWS consecutive. It +rarely starts out that way in the graph. + +Add new backend tests. +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 334 +++++++++++++----- + .../ggml-vulkan/vulkan-shaders/rope_head.glsl | 2 + + .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 13 +- + .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 13 +- + .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 + + tests/test-backend-ops.cpp | 122 +++++-- + 6 files changed, 371 insertions(+), 117 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index b2855b078..aaf4334b5 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -458,6 +458,11 @@ static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { + return mode; + } + ++static constexpr std::initializer_list> rope_view_set_rows_edges { ++ { 1, 0, 0 }, // view->src[0] == rope ++ { 2, 0, 1 }, // set_rows->src[0] == view ++}; ++ + struct vk_device_struct { + std::recursive_mutex mutex; + +@@ -640,8 +645,8 @@ struct vk_device_struct { + vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_soft_max_back_f32; +- vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; +- vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; ++ vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; ++ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; + vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; +@@ -1054,6 +1059,7 @@ struct vk_op_rope_push_constants { + uint32_t s2; + int32_t sections[4]; + uint32_t is_back; ++ uint32_t set_rows_stride; + }; + + struct vk_op_soft_max_push_constants { +@@ -1563,6 +1569,10 @@ struct ggml_backend_vk_context { + // number of additional consecutive nodes that are being fused with the + // node currently being processed + int num_additional_fused_ops {}; ++ // Bitmask of which fused ops need to write an intermediate value to memory. ++ // Bit 'i' means nodes[start_of_fusion + i] writes to memory. ++ // If there's no fusion, bit 0 is still set. 
++ int fused_ops_write_mask {}; + }; + + static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT +@@ -3697,21 +3707,27 @@ static void ggml_vk_load_shaders(vk_device& device) { + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); + +- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { +- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, 
rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } else { +- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); +- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } + + for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { +@@ -8170,7 +8186,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { +- const int mode = ((const int32_t *) dst->op_params)[2]; ++ const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? 
dst->src[0]->src[0] : dst; ++ const int mode = ((const int32_t *) rope->op_params)[2]; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; +@@ -8179,6 +8196,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_neox_f32; + } ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_rope_neox_f32_f16; ++ } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_neox_f16; + } +@@ -8200,6 +8220,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_norm_f32; + } ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_rope_norm_f32_f16; ++ } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_norm_f16; + } +@@ -8409,20 +8432,22 @@ static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_ten + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; + } + +-template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + GGML_UNUSED(p); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(src2); ++ GGML_UNUSED(src3); + GGML_UNUSED(dst); + static_assert(!std::is_const::value, "unexpected type"); + GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); + GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); + GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); ++ GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0); + GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); + } + +-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + +@@ -8430,9 +8455,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk + + GGML_UNUSED(src1); + GGML_UNUSED(src2); ++ GGML_UNUSED(src3); + } + +-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + const uint32_t 
a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + +@@ -8440,9 +8466,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk + + GGML_UNUSED(src1); + GGML_UNUSED(src2); ++ GGML_UNUSED(src3); + } + +-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + +@@ -8450,9 +8477,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk + + GGML_UNUSED(src1); + GGML_UNUSED(src2); ++ GGML_UNUSED(src3); + } + +-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + +@@ -8460,9 +8488,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk + + GGML_UNUSED(src0); + GGML_UNUSED(src2); ++ GGML_UNUSED(src3); + } + +-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); +@@ -8472,9 +8501,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk + p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; + + GGML_UNUSED(src2); ++ GGML_UNUSED(src3); + } + +-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + +@@ -8483,10 +8513,11 @@ template <> void 
init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk + + GGML_UNUSED(src1); + GGML_UNUSED(src2); ++ GGML_UNUSED(src3); + } + + template +-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { ++static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + if (src1 != nullptr) { + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; +@@ -8494,6 +8525,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + if (src2 != nullptr) { + std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; + } ++ if (src3 != nullptr) { ++ std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3]; ++ } + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT +@@ -8520,6 +8554,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; + const uint64_t ne2 = ne20 * ne21; + ++ const bool use_src3 = src3 != nullptr; ++ const uint64_t ne30 = use_src3 ? src3->ne[0] : 0; ++ const uint64_t ne31 = use_src3 ? src3->ne[1] : 0; ++ const uint64_t ne32 = use_src3 ? src3->ne[2] : 0; ++ const uint64_t ne33 = use_src3 ? src3->ne[3] : 0; ++ const uint64_t ne3 = ne30 * ne31; ++ + const uint64_t ned0 = dst->ne[0]; + const uint64_t ned1 = dst->ne[1]; + const uint64_t ned2 = dst->ne[2]; +@@ -8550,6 +8591,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? 
(ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; ++ ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr; + + vk_buffer d_X = nullptr; + size_t x_buf_offset = 0; +@@ -8557,10 +8599,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + size_t y_buf_offset = 0; + vk_buffer d_Z = nullptr; + size_t z_buf_offset = 0; ++ vk_buffer d_W = nullptr; ++ size_t w_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool src2_uma = false; ++ bool src3_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); +@@ -8573,6 +8618,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); + src2_uma = d_Z != nullptr; + } ++ if (use_src3) { ++ ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset); ++ src3_uma = d_W != nullptr; ++ } + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; +@@ -8594,11 +8643,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; + GGML_ASSERT(d_Z != nullptr); + } ++ if (use_src3 && !src3_uma) { ++ d_W = src3_buf_ctx->dev_buffer; ++ w_buf_offset = vk_tensor_offset(src3) + src3->view_offs; ++ GGML_ASSERT(d_W != nullptr); ++ } + // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. +- init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); ++ init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst); + x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); ++ w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + + std::array elements; +@@ -8799,12 +8854,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + break; + } + +- uint64_t x_sz, y_sz, z_sz, d_sz; ++ uint64_t x_sz, y_sz, z_sz, w_sz, d_sz; + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; ++ w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); + + if (x_buf_offset + x_sz >= d_X->size) { +@@ -8816,6 +8872,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset); + } ++ if (use_src3 && w_buf_offset + w_sz >= d_W->size) { ++ w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset); ++ } + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset); + } +@@ -8823,6 +8882,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03; + y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0; + z_sz = use_src2 ? 
ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0; ++ w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0; + d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3; + } + +@@ -8864,14 +8924,19 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { + // Empty src2 is possible in rope, but the shader needs a buffer +- vk_subbuffer subbuf_z; ++ vk_subbuffer subbuf_z, subbuf_w; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } ++ if (use_src3) { ++ subbuf_w = { d_W, w_buf_offset, w_sz }; ++ } else { ++ subbuf_w = { d_X, 0, x_sz }; ++ } + +- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements); + } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { + if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { + // buffer device address path doesn't use dst buffer +@@ -8887,6 +8952,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + } else if (op == GGML_OP_OPT_STEP_SGD) { + // OPT_STEP_SGD works on src0, it does not need dst + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); ++ } else if (use_src3) { ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (use_src2) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (use_src1) { +@@ -8901,7 +8968,7 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -8921,7 +8988,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const + // int nb3 = dst->op_params[2] / 4; // 4 bytes of 
float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9046,7 +9113,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9061,7 +9128,7 @@ static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9076,7 +9143,7 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9091,7 +9158,7 @@ static void ggml_vk_div(ggml_backend_vk_context * 
ctx, vk_context& subctx, const + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9106,7 +9173,7 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t src2_type_size = ggml_type_size(src2->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, { + (uint32_t)dst->ne[0], + (uint32_t)dst->ne[1], + (uint32_t)src0->nb[1] / src0_type_size, +@@ -9339,7 +9406,7 @@ static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { + (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], + (uint32_t)src1->nb[1], + (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], +@@ -9457,7 +9524,7 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su + static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun); + } + + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +@@ -9467,7 +9534,7 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9491,7 +9558,7 @@ static 
void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c + sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + } + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + (uint32_t)ggml_nelements(dst), 0, 0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], + (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, +@@ -9505,23 +9572,23 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con + p.param1 = ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); + } + + static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); + } + + static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun); + } + + static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); + } + + static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); + } + + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +@@ -9529,12 +9596,12 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con + p.param1 = ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); + } + + static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, 
GGML_OP_PAD, std::move(p), dryrun); + } + + static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +@@ -9549,17 +9616,17 @@ static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, cons + memcpy(&p.param1, &s01_packed, sizeof(float)); + memcpy(&p.param2, &s23_packed, sizeof(float)); + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); + } + + static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); + } + + static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); + } + + static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +@@ -9575,7 +9642,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const + } + + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne); +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); + } + + static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +@@ -9590,7 +9657,7 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, + return; + } + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9601,13 +9668,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, + } + + static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, 
nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + } + + static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + } + + static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +@@ -9618,7 +9685,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); + } + + static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { +@@ -9641,7 +9708,7 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, + + uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, +@@ -9658,16 +9725,16 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, + + static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + } + + static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + } + + static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const 
ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + } + + static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +@@ -9690,7 +9757,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const + + const uint32_t mode = split ? 2 : (swapped ? 1 : 0); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU, + { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], +@@ -9703,7 +9770,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const + + static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); + } + + static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { +@@ -9728,7 +9795,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? 
nrows_y : (uint32_t)0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], +@@ -9744,7 +9811,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, + + static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); + } + + static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { +@@ -9835,7 +9902,12 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, + }, pc, elements); + } + +-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { ++static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop, bool dryrun = false) { ++ ggml_tensor * dst = cgraph->nodes[node_idx]; ++ const ggml_tensor * src0 = dst->src[0]; ++ const ggml_tensor * src1 = dst->src[1]; ++ const ggml_tensor * src2 = dst->src[2]; ++ const ggml_tensor * src3 = nullptr; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // const int n_ctx = ((int32_t *) dst->op_params)[3]; +@@ -9859,11 +9931,20 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons + uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { ++ uint32_t set_rows_stride = 0; ++ // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride ++ // and overrides the dst and sets src3=row_indices ++ if (ctx->num_additional_fused_ops > 0) { ++ set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type); ++ src3 = cgraph->nodes[node_idx + 2]->src[1]; ++ dst = cgraph->nodes[node_idx + 2]; ++ } ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, { + (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, +- { sections[0], sections[1], sections[2], sections[3] }, backprop ++ { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride, + }, dryrun); + } + +@@ -9872,7 +9953,7 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c + + uint32_t ncols = src0->ne[0]; + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ncols, + op_params[0], + }, dryrun); +@@ -9880,26 +9961,26 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c + + static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor 
* dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); + } + + static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); + } + + static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + p.weight = 1.0f / (float)src0->ne[0]; +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); + } + + static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); + } + + static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + } + + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +@@ -9932,7 +10013,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, { + dst_addr, + batch_offset, offset_delta, + IC, IW, IH, OW, OH, KW, KH, +@@ -10005,7 +10086,7 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, + pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); + } + + static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +@@ -10013,7 +10094,7 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context + const uint32_t max_period = dst->op_params[1]; + const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); + +- 
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + nb1, dim, max_period, + }, dryrun); + } +@@ -10046,7 +10127,7 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& + p.nb1 = static_cast(nb1 / nb0); + p.s0 = static_cast(s0); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); + } + + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { +@@ -10069,7 +10150,7 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c + + const uint32_t parallel_elements = N * OC * OH * OW; + +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + IW, IH, OW, OH, OC, + parallel_elements, + op, +@@ -10123,7 +10204,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); + } + + static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, +@@ -10172,7 +10253,7 @@ static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne12); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); + } + + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +@@ -10196,12 +10277,12 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + +- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); + } + + static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const float * op_params = (const float *)dst->op_params; +- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); + } + + #ifdef GGML_VULKAN_RUN_TESTS +@@ -11327,7 +11408,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: +- case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: +@@ -11401,9 +11481,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, 
ggml_cgraph * cgr + // nodes require synchronization. + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; +- if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { +- need_sync = true; +- break; ++ // If the node actually writes to memory, then check if it needs to sync ++ if (ctx->fused_ops_write_mask & (1 << i)) { ++ if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { ++ need_sync = true; ++ break; ++ } + } + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { +@@ -11430,7 +11513,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list. +- ctx->unsynced_nodes_written.push_back(cur_node); ++ if (ctx->fused_ops_write_mask & (1 << i)) { ++ ctx->unsynced_nodes_written.push_back(cur_node); ++ } + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; +@@ -11621,11 +11706,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr + + break; + case GGML_OP_ROPE: +- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); ++ ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false, dryrun); + + break; + case GGML_OP_ROPE_BACK: +- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun); ++ ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true, dryrun); + + break; + case GGML_OP_ARGSORT: +@@ -12487,6 +12572,41 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc + return true; + } + ++static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, ++ int node_idx) { ++ GGML_UNUSED(ctx); ++ const ggml_tensor *rope = cgraph->nodes[node_idx + 0]; ++ const ggml_tensor *view = cgraph->nodes[node_idx + 1]; ++ const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2]; ++ ++ // ne3 not tested ++ if (rope->src[0]->ne[3] != 1) { ++ return false; ++ } ++ ++ if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) { ++ return false; ++ } ++ ++ if (set_rows->src[1]->type != GGML_TYPE_I64) { ++ return false; ++ } ++ ++ // The view should flatten two dims of rope into one dim ++ if (!ggml_is_contiguous(view) || ++ view->ne[0] != rope->ne[0] * rope->ne[1]) { ++ return false; ++ } ++ ++ // Only norm/neox shaders have the fusion code ++ const int mode = ((const int32_t *) rope->op_params)[2]; ++ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { ++ return false; ++ } ++ ++ return true; ++} ++ + static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + + const ggml_tensor *first_node = cgraph->nodes[node_idx]; +@@ -12562,6 +12682,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ++ ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ++ ggml_vk_can_fuse_rope_set_rows(ctx, 
cgraph, i)) { ++ ctx->num_additional_fused_ops = 2; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { +@@ -12671,20 +12795,31 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; ++ } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ++ ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ++ ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { ++ ctx->num_additional_fused_ops = 2; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; ++ // view of argsort writes to memory ++ ctx->fused_ops_write_mask |= 1 << 3; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; ++ // view of argsort writes to memory ++ ctx->fused_ops_write_mask |= 1 << 3; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && + ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; ++ // view of argsort writes to memory ++ ctx->fused_ops_write_mask |= 1 << 1; + } + } ++ ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; + + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; +@@ -12730,6 +12865,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg + } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; ++ ctx->fused_ops_write_mask = 0; + } + + if (vk_perf_logger_enabled) { +@@ -12887,6 +13023,32 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * + } + if (ok) { + current_set.push_back(j); ++ // Look for ROPE + VIEW + SET_ROWS and make them consecutive ++ if (graph->nodes[j]->op == GGML_OP_ROPE) { ++ int view_idx = -1; ++ int set_rows_idx = -1; ++ for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) { ++ if (view_idx == -1 && ++ graph->nodes[k]->op == GGML_OP_VIEW && ++ graph->nodes[k]->src[0] == graph->nodes[j]) { ++ view_idx = k; ++ continue; ++ } ++ if (view_idx != -1 && ++ set_rows_idx == -1 && ++ graph->nodes[k]->op == GGML_OP_SET_ROWS && ++ graph->nodes[k]->src[0] == graph->nodes[view_idx]) { ++ set_rows_idx = k; ++ break; ++ } ++ } ++ if (set_rows_idx != -1) { ++ current_set.push_back(view_idx); ++ current_set.push_back(set_rows_idx); ++ used[view_idx] = true; ++ used[set_rows_idx] = true; ++ } ++ } + } + } + // Second pass grabs view nodes. 
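The fusion added in this patch hinges on recognizing a ROPE -> VIEW -> SET_ROWS chain by its producer/consumer edges, as done by ggml_vk_can_fuse_rope_set_rows above and the rope_view_set_rows_edges table checked via ggml_check_edges. A minimal standalone sketch of that edge check follows; Node, check_edges and example are simplified, hypothetical stand-ins for illustration only, not the real ggml types or API.

#include <array>
#include <initializer_list>

// Simplified stand-in for a graph node: src[i] points at the node that
// produces this node's i-th input (nullptr when unused).
struct Node {
    std::array<const Node *, 4> src {};
};

// Each {dst_node, src_idx, src_node} triple requires that
// nodes[start + dst_node] reads its src_idx-th input from nodes[start + src_node].
static bool check_edges(const Node * const * nodes, int start,
                        std::initializer_list<std::array<int, 3>> edges) {
    for (const auto & e : edges) {
        if (nodes[start + e[0]]->src[e[1]] != nodes[start + e[2]]) {
            return false;
        }
    }
    return true;
}

// Wire up rope -> view -> set_rows and verify the expected pattern,
// mirroring rope_view_set_rows_edges: view->src[0] == rope (edge {1, 0, 0})
// and set_rows->src[0] == view (edge {2, 0, 1}).
static bool example() {
    Node rope, view, set_rows;
    view.src[0]     = &rope;
    set_rows.src[0] = &view;
    const Node * nodes[] = { &rope, &view, &set_rows };
    return check_edges(nodes, 0, { { 1, 0, 0 }, { 2, 0, 1 } });  // returns true
}

In the real backend the same kind of check runs before ggml_vk_can_fuse_rope_set_rows accepts the subgraph, and the graph-optimize pass above reorders the VIEW and SET_ROWS nodes so the three ops become consecutive.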
+diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +index 50fc1f1e2..0eda186c8 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; + layout (binding = 1) readonly buffer Y {int data_pos[];}; + layout (binding = 2) readonly buffer Z {float data_ff[];}; + layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; ++layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows + + layout (push_constant) uniform parameter { + uint ncols; +@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter { + uint s2; + int sections[4]; + uint is_back; ++ uint set_rows_stride; + } p; + + float rope_yarn_ramp(const float low, const float high, const uint i0) { +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +index 06e095bef..9f4538155 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +@@ -16,12 +16,19 @@ void main() { + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + +- const uint idst = row_dst*ne0 + i0/2; ++ uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + ++ // Fusion optimization: ROPE + VIEW + SET_ROWS.. ++ // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. ++ if (p.set_rows_stride != 0) { ++ idst = row_x*ne0 + i0/2; ++ idst += data_i[channel_x].x * p.set_rows_stride; ++ } ++ + if (i0 >= p.n_dims) { +- data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; +- data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; ++ data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]); ++ data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]); + + return; + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +index 6ba957540..f4209ed95 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +@@ -16,12 +16,19 @@ void main() { + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + +- const uint idst = row_dst*ne0 + i0; ++ uint idst = row_dst*ne0 + i0; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + ++ // Fusion optimization: ROPE + VIEW + SET_ROWS.. ++ // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. 
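++    // i.e. the write lands at data_i[channel_x].x * p.set_rows_stride + row_x*ne0 + i0,
++    // with data_i[channel_x].x being the destination row chosen by SET_ROWS.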
++ if (p.set_rows_stride != 0) { ++ idst = row_x*ne0 + i0; ++ idst += data_i[channel_x].x * p.set_rows_stride; ++ } ++ + if (i0 >= p.n_dims) { +- data_d[idst + 0] = data_a[ix + 0]; +- data_d[idst + 1] = data_a[ix + 1]; ++ data_d[idst + 0] = D_TYPE(data_a[ix + 0]); ++ data_d[idst + 1] = D_TYPE(data_a[ix + 1]); + + return; + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +index 03fa01639..e6ec589fb 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +@@ -842,10 +842,14 @@ void process_shaders() { + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); ++ string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); ++ string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); ++ string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); ++ string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); +diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp +index 9eb2b6687..657b6cc2f 100644 +--- a/tests/test-backend-ops.cpp ++++ b/tests/test-backend-ops.cpp +@@ -2105,6 +2105,34 @@ struct test_get_rows_back : public test_case { + } + }; + ++static void init_set_rows_row_ids(ggml_tensor * t, int num_rows) { ++ std::random_device rd; ++ std::default_random_engine rng(rd()); ++ for (int i2 = 0; i2 < t->ne[2]; i2++) { ++ for (int i1 = 0; i1 < t->ne[1]; i1++) { ++ // generate a shuffled subset of row indices ++ std::vector data(num_rows); ++ for (int i = 0; i < num_rows; i++) { ++ data[i] = i; ++ } ++ std::shuffle(data.begin(), data.end(), rng); ++ data.resize(t->ne[0]); ++ ++ const size_t offs = i1*t->nb[1] + i2*t->nb[2]; ++ if (t->type == GGML_TYPE_I32) { ++ // TODO: Make a template or something ++ std::vector data_i32(t->ne[0]); ++ for (int i = 0; i < t->ne[0]; i++) { ++ data_i32[i] = static_cast(data[i]); ++ } ++ ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t)); ++ } else { ++ ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t)); ++ } ++ } ++ } ++} ++ + // GGML_OP_SET_ROWS + struct test_set_rows : public test_case { + const ggml_type type; +@@ -2148,37 +2176,13 @@ struct test_set_rows : public test_case { + } + + void initialize_tensors(ggml_context * ctx) override { +- std::random_device rd; +- std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I64 || t->type == 
GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { + continue; + } + +- for (int i2 = 0; i2 < t->ne[2]; i2++) { +- for (int i1 = 0; i1 < t->ne[1]; i1++) { +- // generate a shuffled subset of row indices +- std::vector data(ne[1]); +- for (int i = 0; i < ne[1]; i++) { +- data[i] = i; +- } +- std::shuffle(data.begin(), data.end(), rng); +- data.resize(t->ne[0]); +- +- const size_t offs = i1*t->nb[1] + i2*t->nb[2]; +- if (t->type == GGML_TYPE_I32) { +- // TODO: Make a template or something +- std::vector data_i32(t->ne[0]); +- for (int i = 0; i < t->ne[0]; i++) { +- data_i32[i] = static_cast(data[i]); +- } +- ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t)); +- } else { +- ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t)); +- } +- } +- } ++ init_set_rows_row_ids(t, ne[1]); + } else { + init_tensor_uniform(t); + } +@@ -2207,6 +2211,67 @@ struct test_set_rows : public test_case { + } + }; + ++// GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS ++struct test_rope_set_rows : public test_case { ++ const ggml_type type; ++ const ggml_type type_idx; ++ const std::array ne; ++ int mode; ++ ++ std::string vars() override { ++ return VARS_TO_STR4(type, type_idx, ne, mode); ++ } ++ ++ std::string op_desc(ggml_tensor * t) override { ++ GGML_UNUSED(t); ++ return "ROPE_SET_ROWS"; ++ } ++ ++ bool run_whole_graph() override { return true; } ++ ++ test_rope_set_rows(ggml_type type, ++ ggml_type type_idx, ++ std::array ne, ++ int mode) ++ : type(type), type_idx(type_idx), ne(ne), mode(mode) {} ++ ++ ggml_tensor * build_graph(ggml_context * ctx) override { ++ ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1); ++ ggml_set_name(src, "src"); ++ ++ ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]); ++ ++ ggml_tensor * rope = ggml_rope(ctx, src, pos, ne[0], mode); ++ ++ ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0); ++ ++ ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0] * ne[1], ne[2] * ne[3], 1, 1); ++ ggml_set_name(dst, "dst"); ++ ++ ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne[2], 1, 1); ++ ggml_set_name(row_idxs, "row_idxs"); ++ ++ ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs); ++ ggml_set_name(out, "out"); ++ ++ return out; ++ } ++ ++ void initialize_tensors(ggml_context * ctx) override { ++ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { ++ if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) { ++ if (ggml_is_view_op(t->op)) { ++ continue; ++ } ++ ++ init_set_rows_row_ids(t, ne[2]); ++ } else { ++ init_tensor_uniform(t); ++ } ++ } ++ } ++}; ++ + // GGML_OP_ARGMAX + struct test_argmax : public test_case { + const ggml_type type; +@@ -6008,6 +6073,13 @@ static std::vector> make_test_cases_eval() { + } + } + ++ for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX }) { ++ for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { ++ test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 1, 100 }, mode)); ++ test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 512, 1 }, mode)); ++ } ++ } ++ + for (ggml_type type_input : {GGML_TYPE_F32}) { + for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { + for (int k0 : {1, 3}) { diff --git a/llama/patches/0033-vulkan-Handle-argsort-with-a-large-number-of-rows-16.patch b/llama/patches/0033-vulkan-Handle-argsort-with-a-large-number-of-rows-16.patch new file mode 100644 index 
000000000..27a50a5f8 --- /dev/null +++ b/llama/patches/0033-vulkan-Handle-argsort-with-a-large-number-of-rows-16.patch @@ -0,0 +1,85 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jeff Bolz +Date: Thu, 30 Oct 2025 01:27:41 -0500 +Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851) + +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++++ + ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++---- + 2 files changed, 16 insertions(+), 4 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index aaf4334b5..3604ceb04 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants { + + struct vk_op_argsort_push_constants { + uint32_t ncols; ++ uint32_t nrows; + int32_t order; + }; + +@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co + break; + case GGML_OP_ARGSORT: + elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; ++ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + break; + case GGML_OP_IM2COL: + { +@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c + int32_t * op_params = (int32_t *)dst->op_params; + + uint32_t ncols = src0->ne[0]; ++ uint32_t nrows = ggml_nrows(src0); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ncols, ++ nrows, + op_params[0], + }, dryrun); + } +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +index c81b84452..c4e68bc02 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];}; + + layout (push_constant) uniform parameter { + uint ncols; ++ uint nrows; + uint order; + } p; + +@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) { + dst_row[idx1] = tmp; + } + +-void argsort(bool needs_bounds_check) { ++void argsort(bool needs_bounds_check, const uint row) { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); +- const uint row = gl_WorkGroupID.y; + + const uint row_offset = row * p.ncols; + +@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) { + + void main() { + if (p.ncols == BLOCK_SIZE) { +- argsort(false); ++ uint row = gl_WorkGroupID.y; ++ while (row < p.nrows) { ++ argsort(false, row); ++ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; ++ } + } else { +- argsort(true); ++ uint row = gl_WorkGroupID.y; ++ while (row < p.nrows) { ++ argsort(true, row); ++ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; ++ } + } + } diff --git a/llama/patches/0034-vulkan-fix-shmem-overrun-in-mmq-id-shader-16873.patch b/llama/patches/0034-vulkan-fix-shmem-overrun-in-mmq-id-shader-16873.patch new file mode 100644 index 000000000..73dad676c --- /dev/null +++ b/llama/patches/0034-vulkan-fix-shmem-overrun-in-mmq-id-shader-16873.patch @@ -0,0 +1,77 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Ruben Ortlam +Date: Fri, 31 Oct 2025 08:14:49 +0100 +Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873) + +* vulkan: fix shmem overrun in mmq id shader + +* metal : fix mul_mm_id + +--------- + +Co-authored-by: Georgi Gerganov +--- + ggml/src/ggml-metal/ggml-metal-device.cpp | 2 +- + 
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 ++++ + ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +- + tests/test-backend-ops.cpp | 3 +++ + 4 files changed, 9 insertions(+), 2 deletions(-) + +diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp +index 758116342..c78082ac3 100644 +--- a/ggml/src/ggml-metal/ggml-metal-device.cpp ++++ b/ggml/src/ggml-metal/ggml-metal-device.cpp +@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_ + char name[256]; + + snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); +- snprintf(name, 256, "%s", base); ++ snprintf(name, 256, "%s_ne02=%d", base, ne02); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +index 8b238ac4b..d955b4fc7 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32; + + #include "mul_mmq_shmem_types.glsl" + ++#ifdef MUL_MAT_ID ++#define BK_STEP 1 ++#else + #ifndef BK_STEP + #define BK_STEP 4 + #endif ++#endif + + // Shared memory cache + shared block_a_cache buf_a[BM * BK_STEP]; +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +index 72fec4404..1c0f5306f 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +@@ -27,7 +27,7 @@ struct block_a_cache { + #elif defined(DATA_A_Q8_0) + #define QUANT_R_MMQ 1 + // AMD likes 4, Intel likes 1 and Nvidia likes 2 +-#define BK_STEP 1 ++// #define BK_STEP 1 + struct block_a_cache { + int32_t qs[32/4]; + FLOAT_TYPE dm; +diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp +index 657b6cc2f..1f8dda383 100644 +--- a/tests/test-backend-ops.cpp ++++ b/tests/test-backend-ops.cpp +@@ -6722,6 +6722,9 @@ static std::vector> make_test_cases_eval() { + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3)); + ++ // gpt-oss issue with Vulkan mmq_id ++ test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); ++ + for (ggml_type type_a : base_types) { + for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { + for (int n_mats : {4, 8}) { diff --git a/llama/patches/0035-vulkan-Fix-crash-when-FP16-mul_mat-accumulation-is-n.patch b/llama/patches/0035-vulkan-Fix-crash-when-FP16-mul_mat-accumulation-is-n.patch new file mode 100644 index 000000000..dfa469160 --- /dev/null +++ b/llama/patches/0035-vulkan-Fix-crash-when-FP16-mul_mat-accumulation-is-n.patch @@ -0,0 +1,80 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Masato Nakasaka +Date: Fri, 31 Oct 2025 16:18:59 +0900 +Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not + supported (#16796) + +* Experimenting crash fix + +* added assert for aborting and fixed comment + +* changed to check if a pipeline is empty or not + +* Moved function in class definition + +* replaced with is_empty + +* Modified is_empty to check only unaligned pipelines +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++------- + 1 file 
changed, 13 insertions(+), 7 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 3604ceb04..80185d9f0 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); + struct vk_matmul_pipeline_struct { + vk_pipeline l, m, s; + vk_pipeline a_l, a_m, a_s; ++ // Returns true when all unaligned pipelines are null. ++ // We only check for unaligned variants since one of the unaligned pipelines must exist ++ // while aligned pipelines are optional ++ bool is_empty() const { ++ return l == nullptr && m == nullptr && s == nullptr; ++ } + }; +- + typedef std::shared_ptr vk_matmul_pipeline; + + struct vk_matmul_pipeline2 { +@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + +- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { ++ if (pipelines->is_empty()) { + return nullptr; + } + +@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc; + +- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { ++ if (pipelines->is_empty()) { + return nullptr; + } + +@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co + return nullptr; + } + ++ vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type]; + // XXX TODO 'prec' is not actually allowed in mul_mat_id. + bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/; +- bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr; +- bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr; ++ bool support_fp16acc = !mmp.f16acc->is_empty(); ++ bool support_fp32acc = !mmp.f32acc->is_empty(); + + if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) { +- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc; ++ return mmp.f16acc; + } else { + GGML_ASSERT(support_fp32acc); +- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; ++ return mmp.f32acc; + } + } + diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h index 639d551a2..e5c446d1d 100644 --- a/ml/backend/ggml/ggml/src/ggml-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-impl.h @@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release(); #endif #ifdef __cplusplus +#include #include #include @@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size()); } +// Return true if the edges in the graph match expectations. 
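+// Each edge is a {dst_node, src_idx, src_node} triple relative to start_idx:
+// cgraph->nodes[start_idx + dst_node] must read its src_idx-th input from
+// cgraph->nodes[start_idx + src_node].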
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, + int start_idx, + std::initializer_list> edges) { + for (const auto & edge : edges) { + int dst_node = edge[0]; + int src_idx = edge[1]; + int src_node = edge[2]; + if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) { + return false; + } + } + return true; +} + // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp index 758116342..c78082ac3 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_ char name[256]; snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); - snprintf(name, 256, "%s", base); + snprintf(name, 256, "%s_ne02=%d", base, ne02); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 221e29509..80185d9f0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); struct vk_matmul_pipeline_struct { vk_pipeline l, m, s; vk_pipeline a_l, a_m, a_s; + // Returns true when all unaligned pipelines are null. + // We only check for unaligned variants since one of the unaligned pipelines must exist + // while aligned pipelines are optional + bool is_empty() const { + return l == nullptr && m == nullptr && s == nullptr; + } }; - typedef std::shared_ptr vk_matmul_pipeline; struct vk_matmul_pipeline2 { @@ -387,12 +392,81 @@ static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; -static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; -static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }; +static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, + GGML_OP_RESHAPE }; +static constexpr std::initializer_list topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }; +static constexpr std::initializer_list topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] +//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] +//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] +//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ] +//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) 
[Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ] +//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ] +//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] +//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ] +//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ] +//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ] +static constexpr std::initializer_list> topk_moe_early_softmax_norm_edges { + { 1, 0, 0 }, // reshape->src[0] == softmax + { 2, 0, 0 }, // argsort->src[0] == softmax + { 3, 0, 2 }, // view->src[0] == argsort + { 4, 0, 1 }, // get_rows->src[0] == reshape + { 4, 1, 3 }, // get_rows->src[1] == view + { 5, 0, 4 }, // reshape->src[0] == get_rows + { 6, 0, 5 }, // sum_rows->src[0] == reshape + { 7, 0, 6 }, // clamp->src[0] == sum_rows + { 8, 0, 5 }, // div->src[0] == reshape + { 8, 1, 7 }, // div->src[1] == clamp + { 9, 0, 8 }, // reshape->src[0] == div +}; + +// same as early_softmax_norm but ending after the get_rows +static constexpr std::initializer_list> topk_moe_early_softmax_edges { + { 1, 0, 0 }, // reshape->src[0] == softmax + { 2, 0, 0 }, // argsort->src[0] == softmax + { 3, 0, 2 }, // view->src[0] == argsort + { 4, 0, 1 }, // get_rows->src[0] == reshape + { 4, 1, 3 }, // get_rows->src[1] == view +}; + +//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ] +//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ] +//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ] +//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ] +//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ] +//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ] +static constexpr std::initializer_list> topk_moe_late_softmax_edges { + { 1, 0, 0 }, // view->src[0] == argsort + { 2, 1, 1 }, // get_rows->src[1] == view + { 3, 0, 2 }, // reshape->src[0] == get_rows + { 4, 0, 3 }, // soft_max->src[0] == reshape + { 5, 0, 4 }, // reshape->src[0] == soft_max +}; + +enum topk_moe_mode { + TOPK_MOE_EARLY_SOFTMAX, + TOPK_MOE_EARLY_SOFTMAX_NORM, + TOPK_MOE_LATE_SOFTMAX, + TOPK_MOE_COUNT, +}; + +static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { + topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : + num == topk_moe_early_softmax.size() - 1 ? 
TOPK_MOE_EARLY_SOFTMAX : + TOPK_MOE_LATE_SOFTMAX; + return mode; +} + +static constexpr std::initializer_list> rope_view_set_rows_edges { + { 1, 0, 0 }, // view->src[0] == rope + { 2, 0, 1 }, // set_rows->src[0] == view +}; struct vk_device_struct { std::recursive_mutex mutex; @@ -488,6 +562,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT]; vk_pipeline pipeline_matmul_split_k_reduce; vk_pipeline pipeline_quantize_q8_1; @@ -575,8 +650,8 @@ struct vk_device_struct { vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_soft_max_back_f32; - vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; - vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; + vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; @@ -606,8 +681,7 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_split_k_reduce; - // [2] is {!norm, norm} - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT]; std::vector all_pipelines; @@ -955,6 +1029,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); struct vk_op_topk_moe_push_constants { uint32_t n_rows; uint32_t n_expert_used; + float clamp_min; + float clamp_max; }; struct vk_op_add_id_push_constants { @@ -988,6 +1064,7 @@ struct vk_op_rope_push_constants { uint32_t s2; int32_t sections[4]; uint32_t is_back; + uint32_t set_rows_stride; }; struct vk_op_soft_max_push_constants { @@ -1012,6 +1089,7 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; + uint32_t nrows; int32_t order; }; @@ -1497,6 +1575,10 @@ struct ggml_backend_vk_context { // number of additional consecutive nodes that are being fused with the // node currently being processed int num_additional_fused_ops {}; + // Bitmask of which fused ops need to write an intermediate value to memory. + // Bit 'i' means nodes[start_of_fusion + i] writes to memory. + // If there's no fusion, bit 0 is still set. 
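+    // For example, for the fused topk_moe_early_softmax pattern (5 nodes) the
+    // view of the argsort result (node 3) and the final node (node 4) both
+    // write to memory, so the mask is (1 << 3) | (1 << 4).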
+ int fused_ops_write_mask {}; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -2449,8 +2531,11 @@ static void ggml_vk_load_shaders(vk_device& device) { l_warptile_id, m_warptile_id, s_warptile_id, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, + l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, - l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid, + l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int, + l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k; std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, @@ -2513,10 +2598,16 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + // Integer MMQ has a smaller shared memory profile, but heavier register use l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + // K-quants use even more registers, mitigate by setting WMITER to 1 + l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 }; + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; @@ -2525,10 +2616,18 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; + m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; + s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; + + l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; + m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; + s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; + // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; - m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 
32, 16, 16, 2, 2, 2, 1, 16 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -2913,18 +3012,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ -#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ if (device->mul_mat ## ID ## _m[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ if (device->mul_mat ## ID ## _s[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ // Create 2 variants, {f16,f32} accumulator @@ -2963,11 +3059,19 @@ static void ggml_vk_load_shaders(vk_device& device) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_1, 
pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); } #endif @@ -2997,6 +3101,24 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 
mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + } +#endif } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); @@ -3023,6 +3145,24 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, 
warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + } +#endif } #undef CREATE_MM2 #undef CREATE_MMQ @@ -3087,6 +3227,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); } #endif @@ -3146,7 +3292,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && !device->coopmat_bf16_support #endif ) { @@ -3567,21 +3713,27 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); - ggml_vk_create_pipeline(device, 
device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 
512, 1}, {}, 1); } else { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { @@ -3741,8 +3893,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<device->fp16 && prec == GGML_PREC_DEFAULT) ? 
ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; - if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + if (pipelines->is_empty()) { return nullptr; } @@ -5077,6 +5230,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc; + + if (pipelines->is_empty()) { + return nullptr; + } + + return pipelines; + } + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { @@ -5105,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co return nullptr; } + vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type]; // XXX TODO 'prec' is not actually allowed in mul_mat_id. bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/; - bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr; - bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr; + bool support_fp16acc = !mmp.f16acc->is_empty(); + bool support_fp32acc = !mmp.f32acc->is_empty(); if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) { - return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc; + return mmp.f16acc; } else { GGML_ASSERT(support_fp32acc); - return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; + return mmp.f32acc; } } @@ -5654,14 +5819,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); // Copy device to device ggml_vk_ensure_sync_staging_buffer(src->device, size); - ggml_vk_ensure_sync_staging_buffer(dst->device, size); // Copy to src staging buffer ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); - // memcpy to dst staging buffer - memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); // Copy to dst buffer - ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); + ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1); } } @@ -6882,10 +7044,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? 
f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat @@ -6895,8 +7066,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); - const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); @@ -6909,12 +7080,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; - const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t ids_sz = nbi2; const uint64_t d_sz = sizeof(float) * d_ne; vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); @@ -6929,9 +7101,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; - const uint64_t y_sz_upd = y_sz * ne12 * ne13; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } if ( (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { @@ -6940,7 +7119,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } @@ -6952,6 +7131,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } return; } @@ -6988,6 +7170,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, 
vk_context& if (qy_needs_dequant) { d_Y = ctx->prealloc_y; GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -7019,6 +7204,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx->prealloc_y_last_tensor_used = src1; } } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_y = ne10*ne11; @@ -7027,14 +7223,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); } - if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + // compute ggml_vk_matmul_id( ctx, subctx, pipeline, - { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, ne01, ne21, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, @@ -7973,8 +8174,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (ctx->num_additional_fused_ops) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); - bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; - return ctx->device->pipeline_topk_moe[idx][with_norm]; + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + return ctx->device->pipeline_topk_moe[idx][mode]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -7992,7 +8193,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: { - const int mode = ((const int32_t *) dst->op_params)[2]; + const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? 
dst->src[0]->src[0] : dst; + const int mode = ((const int32_t *) rope->op_params)[2]; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; @@ -8001,6 +8203,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rope_neox_f32; } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_neox_f32_f16; + } if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_rope_neox_f16; } @@ -8022,6 +8227,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rope_norm_f32; } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_norm_f32_f16; + } if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_rope_norm_f16; } @@ -8029,6 +8237,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; } case GGML_OP_ARGSORT: + if (ctx->num_additional_fused_ops) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + GGML_ASSERT(idx < num_topk_moe_pipelines); + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + return ctx->device->pipeline_topk_moe[idx][mode]; + } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); return ctx->device->pipeline_argsort_f32[idx]; @@ -8224,20 +8439,22 @@ static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_ten return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; } -template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { GGML_UNUSED(p); GGML_UNUSED(src0); GGML_UNUSED(src1); GGML_UNUSED(src2); + GGML_UNUSED(src3); GGML_UNUSED(dst); static_assert(!std::is_const::value, "unexpected type"); GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); + GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0); GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); } -template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8245,9 +8462,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk 
GGML_UNUSED(src1); GGML_UNUSED(src2); + GGML_UNUSED(src3); } -template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8255,9 +8473,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src1); GGML_UNUSED(src2); + GGML_UNUSED(src3); } -template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8265,9 +8484,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src1); GGML_UNUSED(src2); + GGML_UNUSED(src3); } -template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8275,9 +8495,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src0); GGML_UNUSED(src2); + GGML_UNUSED(src3); } -template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8287,9 +8508,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; GGML_UNUSED(src2); + GGML_UNUSED(src3); } -template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +template <> void 
init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8298,10 +8520,11 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src1); GGML_UNUSED(src2); + GGML_UNUSED(src3); } template -static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; if (src1 != nullptr) { std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -8309,6 +8532,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co if (src2 != nullptr) { std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; } + if (src3 != nullptr) { + std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3]; + } std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT @@ -8335,6 +8561,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; const uint64_t ne2 = ne20 * ne21; + const bool use_src3 = src3 != nullptr; + const uint64_t ne30 = use_src3 ? src3->ne[0] : 0; + const uint64_t ne31 = use_src3 ? src3->ne[1] : 0; + const uint64_t ne32 = use_src3 ? src3->ne[2] : 0; + const uint64_t ne33 = use_src3 ? 
src3->ne[3] : 0; + const uint64_t ne3 = ne30 * ne31; + const uint64_t ned0 = dst->ne[0]; const uint64_t ned1 = dst->ne[1]; const uint64_t ned2 = dst->ne[2]; @@ -8365,6 +8598,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr; vk_buffer d_X = nullptr; size_t x_buf_offset = 0; @@ -8372,10 +8606,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co size_t y_buf_offset = 0; vk_buffer d_Z = nullptr; size_t z_buf_offset = 0; + vk_buffer d_W = nullptr; + size_t w_buf_offset = 0; bool src0_uma = false; bool src1_uma = false; bool src2_uma = false; + bool src3_uma = false; if (ctx->device->uma) { ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); @@ -8388,6 +8625,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); src2_uma = d_Z != nullptr; } + if (use_src3) { + ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset); + src3_uma = d_W != nullptr; + } } vk_buffer d_D = dst_buf_ctx->dev_buffer; @@ -8409,11 +8650,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; GGML_ASSERT(d_Z != nullptr); } + if (use_src3 && !src3_uma) { + d_W = src3_buf_ctx->dev_buffer; + w_buf_offset = vk_tensor_offset(src3) + src3->view_offs; + GGML_ASSERT(d_W != nullptr); + } // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. - init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); + init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst); x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); std::array elements; @@ -8470,6 +8717,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; case GGML_OP_ARGSORT: elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); break; case GGML_OP_IM2COL: { @@ -8614,12 +8862,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; } - uint64_t x_sz, y_sz, z_sz, d_sz; + uint64_t x_sz, y_sz, z_sz, w_sz, d_sz; if (op_supports_incontiguous) { x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + w_sz = use_src3 ? 
ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0; d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); if (x_buf_offset + x_sz >= d_X->size) { @@ -8631,6 +8880,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset); } + if (use_src3 && w_buf_offset + w_sz >= d_W->size) { + w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset); + } if (d_buf_offset + d_sz >= d_D->size) { d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset); } @@ -8638,6 +8890,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03; y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0; z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0; + w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0; d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3; } @@ -8679,14 +8932,19 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer - vk_subbuffer subbuf_z; + vk_subbuffer subbuf_z, subbuf_w; if (use_src2) { subbuf_z = { d_Z, z_buf_offset, z_sz }; } else { subbuf_z = { d_X, 0, x_sz }; } + if (use_src3) { + subbuf_w = { d_W, w_buf_offset, w_sz }; + } else { + subbuf_w = { d_X, 0, x_sz }; + } - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements); } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { // buffer device address path doesn't use dst buffer @@ -8702,6 +8960,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } else if (op == GGML_OP_OPT_STEP_SGD) { // OPT_STEP_SGD works on src0, it does not need dst ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); + } else if (use_src3) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src2) { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src1) { @@ -8716,7 +8976,7 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, 
GGML_OP_GET_ROWS, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -8736,7 +8996,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused int offset = dst->op_params[3] / 4; // offset in bytes - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -8861,7 +9121,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -8876,7 +9136,7 @@ static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -8891,7 +9151,7 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, 
subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -8906,7 +9166,7 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -8921,7 +9181,7 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t src2_type_size = ggml_type_size(src2->type); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)src0->nb[1] / src0_type_size, @@ -9154,7 +9414,7 @@ static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], (uint32_t)src1->nb[1], (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], @@ -9272,7 +9532,7 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { const size_t n = ggml_nelements(dst->src[0]); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun); } static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -9282,7 +9542,7 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, 
GGML_OP_CONCAT, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -9306,7 +9566,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); } - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, { (uint32_t)ggml_nelements(dst), 0, 0, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, @@ -9320,23 +9580,23 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con p.param1 = ggml_get_op_params_f32(dst, 0); p.param2 = ggml_get_op_params_f32(dst, 1); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); } static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); } static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun); } static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); } static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); } static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9344,12 +9604,12 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con p.param1 = ggml_get_op_params_f32(dst, 0); p.param2 = ggml_get_op_params_f32(dst, 1); - 
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); } static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); } static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9364,17 +9624,17 @@ static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, cons memcpy(&p.param1, &s01_packed, sizeof(float)); memcpy(&p.param2, &s23_packed, sizeof(float)); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); } static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); } static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); } static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9390,7 +9650,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const } vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); } static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -9405,7 +9665,7 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, return; } - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, 
(uint32_t)src1->nb[3] / src1_type_size, @@ -9416,13 +9676,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, } static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9433,7 +9693,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx const float eps = float_op_params[1]; const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { @@ -9456,7 +9716,7 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, uint32_t param3 = ctx->do_add_rms_partials ? 
ggml_vk_rms_num_partials(ctx, dst) : 0; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -9473,16 +9733,16 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -9505,7 +9765,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t mode = split ? 2 : (swapped ? 
1 : 0); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], @@ -9518,7 +9778,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { int32_t * op_params = (int32_t *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); } static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { @@ -9543,7 +9803,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], @@ -9559,15 +9819,17 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); } static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { - bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; - ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; - ggml_tensor * ids = cgraph->nodes[node_idx + 3]; + ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : + (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : + cgraph->nodes[node_idx + 5]; + ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? 
cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); @@ -9626,9 +9888,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(d_ids != nullptr); } - vk_op_topk_moe_push_constants pc; + vk_op_topk_moe_push_constants pc {}; pc.n_rows = n_rows; pc.n_expert_used = n_expert_used; + if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { + ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; + pc.clamp_min = ggml_get_op_params_f32(clamp, 0); + pc.clamp_max = ggml_get_op_params_f32(clamp, 1); + } GGML_ASSERT(n_expert_used <= n_experts); @@ -9643,7 +9910,12 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, }, pc, elements); } -static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop, bool dryrun = false) { + ggml_tensor * dst = cgraph->nodes[node_idx]; + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const ggml_tensor * src3 = nullptr; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -9667,11 +9939,20 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { + uint32_t set_rows_stride = 0; + // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride + // and overrides the dst and sets src3=row_indices + if (ctx->num_additional_fused_ops > 0) { + set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type); + src3 = cgraph->nodes[node_idx + 2]->src[1]; + dst = cgraph->nodes[node_idx + 2]; + } + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, - { sections[0], sections[1], sections[2], sections[3] }, backprop + { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride, }, dryrun); } @@ -9679,35 +9960,37 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c int32_t * op_params = (int32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; + uint32_t nrows = ggml_nrows(src0); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { ncols, + nrows, op_params[0], }, dryrun); } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); } static void 
ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); } static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); p.weight = 1.0f / (float)src0->ne[0]; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); } static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -9740,7 +10023,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, { dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, @@ -9813,7 +10096,7 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); } static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9821,7 +10104,7 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context const uint32_t max_period = dst->op_params[1]; const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { nb1, dim, max_period, }, dryrun); } @@ -9854,7 +10137,7 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& p.nb1 = static_cast(nb1 / nb0); p.s0 = static_cast(s0); - ggml_vk_op_f32(ctx, subctx, 
src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); } static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9877,7 +10160,7 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c const uint32_t parallel_elements = N * OC * OH * OW; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, { IW, IH, OW, OH, OC, parallel_elements, op, @@ -9931,7 +10214,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, GGML_ASSERT(ne03 == ne2); GGML_ASSERT(ne02 == ne12); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); } static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, @@ -9980,7 +10263,7 @@ static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context GGML_ASSERT(ne02 == ne2); GGML_ASSERT(ne03 == ne12); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); } static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -10004,12 +10287,12 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(src0->ne[3] == p.channels); GGML_ASSERT(src1->ne[3] == p.batches); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); } static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const float * op_params = (const float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); } #ifdef GGML_VULKAN_RUN_TESTS @@ -11135,7 +11418,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_ARGSORT: case GGML_OP_SUM: @@ -11209,9 +11491,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr // nodes require synchronization. 
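// The hunk below narrows synchronization to fused nodes that actually write to
// memory, tracked per fused group in ctx->fused_ops_write_mask. A minimal
// self-contained sketch of that bookkeeping follows; Node, Context and overlaps()
// are simplified stand-ins for the backend's types, and the real code also checks
// each node's source tensors and a separate list of unsynced reads.
#include <cstddef>
#include <cstdint>
#include <vector>

struct Node { const void *data; size_t size; };    // stand-in: a tensor's buffer range

static bool overlaps(const Node &n, const std::vector<const Node *> &set) {
    const uintptr_t a0 = (uintptr_t)n.data, a1 = a0 + n.size;
    for (const Node *m : set) {
        const uintptr_t b0 = (uintptr_t)m->data, b1 = b0 + m->size;
        if (a0 < b1 && b0 < a1) {
            return true;                           // byte ranges intersect
        }
    }
    return false;
}

struct Context {
    uint32_t fused_ops_write_mask = 0;             // bit i set => fused node i writes memory
    std::vector<const Node *> unsynced_written;    // results not yet covered by a barrier
};

// Returns true if a barrier is needed before recording this fused group. Only nodes
// flagged as writers can introduce a hazard; intermediate nodes that the fused shader
// never materializes to memory are skipped.
static bool group_needs_sync(const Context &ctx, const std::vector<Node> &fused) {
    for (size_t i = 0; i < fused.size(); ++i) {
        if ((ctx.fused_ops_write_mask & (1u << i)) && overlaps(fused[i], ctx.unsynced_written)) {
            return true;
        }
    }
    return false;
}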
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; - if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { - need_sync = true; - break; + // If the node actually writes to memory, then check if it needs to sync + if (ctx->fused_ops_write_mask & (1 << i)) { + if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } } for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { if (!cur_node->src[j]) { @@ -11223,7 +11508,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } } + +#define ENABLE_SYNC_LOGGING 0 + if (need_sync) { +#if ENABLE_SYNC_LOGGING + std::cerr << "sync" << std::endl; +#endif ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); ggml_vk_sync_buffers(ctx, compute_ctx); @@ -11232,7 +11523,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list. - ctx->unsynced_nodes_written.push_back(cur_node); + if (ctx->fused_ops_write_mask & (1 << i)) { + ctx->unsynced_nodes_written.push_back(cur_node); + } for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { if (!cur_node->src[j]) { continue; @@ -11241,6 +11534,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } } +#if ENABLE_SYNC_LOGGING + if (!dryrun) { + for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + auto *n = cgraph->nodes[node_idx + i]; + std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; + if (n->op == GGML_OP_GLU) { + std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? 
"split" : "single") << " "; + } + std::cerr << std::endl; + } + } +#endif switch (node->op) { case GGML_OP_REPEAT: @@ -11411,15 +11716,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ROPE: - ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); + ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false, dryrun); break; case GGML_OP_ROPE_BACK: - ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun); + ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true, dryrun); break; case GGML_OP_ARGSORT: - ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + if (ctx->num_additional_fused_ops) { + ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun); + } else { + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + } break; case GGML_OP_SUM: @@ -12217,31 +12526,28 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st } static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, - int node_idx, bool with_norm) { + int node_idx, topk_moe_mode mode) { - if (with_norm) { - if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) { - return false; - } - for (size_t i = 0; i < topk_moe_norm.size(); ++i) { - if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) { - return false; - } - } - } else { - if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) { - return false; - } - for (size_t i = 0; i < topk_moe.size(); ++i) { - if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) { - return false; - } - } + const ggml_tensor * softmax; + const ggml_tensor * weights; + + switch (mode) { + case TOPK_MOE_EARLY_SOFTMAX_NORM: + softmax = cgraph->nodes[node_idx + 0]; + weights = cgraph->nodes[node_idx + 9]; + break; + case TOPK_MOE_EARLY_SOFTMAX: + softmax = cgraph->nodes[node_idx + 0]; + weights = cgraph->nodes[node_idx + 4]; + break; + case TOPK_MOE_LATE_SOFTMAX: + softmax = cgraph->nodes[node_idx + 4]; + weights = cgraph->nodes[node_idx + 5]; + break; + default: + return false; } - const ggml_tensor * softmax = cgraph->nodes[node_idx + 0]; - const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; - const float * op_params = (const float *)softmax->op_params; float scale = op_params[0]; @@ -12266,60 +12572,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc return false; } - // Check that the nodes don't have any unexpected uses - const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1]; - const ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - const ggml_tensor * view = cgraph->nodes[node_idx + 3]; - const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr; - const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr; - const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr; - const ggml_tensor * reshape8 = with_norm ? 
cgraph->nodes[node_idx + 8] : nullptr; - - // softmax is used by reshape and argsort - if (ggml_node_get_use_count(cgraph, node_idx) != 2 || - reshape1->src[0] != softmax || - argsort->src[0] != softmax) { - return false; - } - // reshape is used by get_rows - if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 || - get_rows->src[0] != reshape1) { - return false; - } - // argsort is used by view - if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 || - view->src[0] != argsort) { - return false; - } - // view is written (via argsort), we can skip checking it - - if (with_norm) { - // get_rows is used by reshape - if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 || - reshape5->src[0] != get_rows) { - return false; - } - - // reshape is used by sum_rows and div - if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 || - sum_rows->src[0] != reshape5 || - div->src[0] != reshape5) { - return false; - } - - // sum_rows is used by div - if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 || - div->src[1] != sum_rows) { - return false; - } - - // div/reshape are written - if (reshape8->src[0] != div) { - return false; - } - } - if (!ctx->device->subgroup_arithmetic || !ctx->device->subgroup_shuffle || !ctx->device->subgroup_require_full_support || @@ -12330,6 +12582,41 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc return true; } +static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx) { + GGML_UNUSED(ctx); + const ggml_tensor *rope = cgraph->nodes[node_idx + 0]; + const ggml_tensor *view = cgraph->nodes[node_idx + 1]; + const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2]; + + // ne3 not tested + if (rope->src[0]->ne[3] != 1) { + return false; + } + + if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) { + return false; + } + + if (set_rows->src[1]->type != GGML_TYPE_I64) { + return false; + } + + // The view should flatten two dims of rope into one dim + if (!ggml_is_contiguous(view) || + view->ne[0] != rope->ne[0] * rope->ne[1]) { + return false; + } + + // Only norm/neox shaders have the fusion code + const int mode = ((const int32_t *) rope->op_params)[2]; + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + return false; + } + + return true; +} + static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { const ggml_tensor *first_node = cgraph->nodes[node_idx]; @@ -12405,10 +12692,22 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { - ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { - ctx->num_additional_fused_ops = topk_moe.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && + ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && + ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 2; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { + 
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && + ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; } } ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); @@ -12506,12 +12805,31 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { - ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { - ctx->num_additional_fused_ops = topk_moe.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && + ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && + ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 2; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; + // view of argsort writes to memory + ctx->fused_ops_write_mask |= 1 << 3; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; + // view of argsort writes to memory + ctx->fused_ops_write_mask |= 1 << 3; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && + ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; + // view of argsort writes to memory + ctx->fused_ops_write_mask |= 1 << 1; } } + ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; @@ -12557,6 +12875,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } i += ctx->num_additional_fused_ops; ctx->num_additional_fused_ops = 0; + ctx->fused_ops_write_mask = 0; } if (vk_perf_logger_enabled) { @@ -12642,25 +12961,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * while (first_unused < graph->n_nodes) { std::vector current_set; - // Avoid reordering topk_moe_norm - if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) { - bool is_topk_moe_norm = true; - for (size_t j = 0; j < topk_moe_norm.size(); ++j) { - if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) { - 
is_topk_moe_norm = false; + // Check for fusion patterns and avoid reordering them + auto const &match_pattern = [&](const std::initializer_list &pattern, int start) -> bool { + if (start + (int)pattern.size() <= graph->n_nodes) { + bool is_pattern = true; + for (size_t j = 0; j < pattern.size(); ++j) { + if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) { + is_pattern = false; + } } + return is_pattern; } - if (is_topk_moe_norm) { - for (size_t j = 0; j < topk_moe_norm.size(); ++j) { + return false; + }; + + auto const &keep_pattern = [&](const std::initializer_list &pattern) -> bool { + if (match_pattern(pattern, first_unused)) { + for (size_t j = 0; j < pattern.size(); ++j) { new_order.push_back(graph->nodes[first_unused + j]); used[first_unused + j] = true; } while (first_unused < graph->n_nodes && used[first_unused]) { first_unused++; } - continue; + return true; } + return false; + }; + + if (keep_pattern(topk_moe_early_softmax_norm)) { + continue; } + if (keep_pattern(topk_moe_early_softmax)) { + continue; + } + if (keep_pattern(topk_moe_late_softmax)) { + continue; + } + // First, grab the next unused node. current_set.push_back(first_unused); @@ -12678,6 +13016,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (is_empty(graph->nodes[j])) { continue; } + // Don't pull forward nodes from fusion patterns + if (match_pattern(topk_moe_early_softmax_norm, j) || + match_pattern(topk_moe_early_softmax, j) || + match_pattern(topk_moe_late_softmax, j)) { + continue; + } bool ok = true; for (int c = first_unused; c < j; ++c) { if (!used[c] && @@ -12689,6 +13033,32 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } if (ok) { current_set.push_back(j); + // Look for ROPE + VIEW + SET_ROWS and make them consecutive + if (graph->nodes[j]->op == GGML_OP_ROPE) { + int view_idx = -1; + int set_rows_idx = -1; + for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) { + if (view_idx == -1 && + graph->nodes[k]->op == GGML_OP_VIEW && + graph->nodes[k]->src[0] == graph->nodes[j]) { + view_idx = k; + continue; + } + if (view_idx != -1 && + set_rows_idx == -1 && + graph->nodes[k]->op == GGML_OP_SET_ROWS && + graph->nodes[k]->src[0] == graph->nodes[view_idx]) { + set_rows_idx = k; + break; + } + } + if (set_rows_idx != -1) { + current_set.push_back(view_idx); + current_set.push_back(set_rows_idx); + used[view_idx] = true; + used[set_rows_idx] = true; + } + } } } // Second pass grabs view nodes. 
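A note on the graph-optimize hunk above: rather than special-casing a single topk_moe_norm sequence, the pass now uses small match/keep helpers so that every fusion pattern it knows about (early-softmax with and without norm, late-softmax) stays contiguous and is never torn apart by reordering. A minimal sketch of the match/keep idea, using a simplified node representation rather than the real ggml_cgraph API:

// Sketch: keep known fusion patterns contiguous while reordering a node list.
// Op and Node are simplified stand-ins for the ggml equivalents.
#include <cstddef>
#include <initializer_list>
#include <vector>

enum Op { SOFT_MAX, RESHAPE, ARGSORT, VIEW, GET_ROWS, ROPE, SET_ROWS, MUL };
struct Node { Op op; };

// True if the ops starting at `start` match `pattern` and none of them were consumed yet.
static bool match_pattern(const std::vector<Node> &nodes, const std::vector<bool> &used,
                          const std::initializer_list<Op> &pattern, size_t start) {
    if (start + pattern.size() > nodes.size()) return false;
    for (size_t j = 0; j < pattern.size(); ++j) {
        if (nodes[start + j].op != pattern.begin()[j] || used[start + j]) return false;
    }
    return true;
}

// If a pattern starts at `first_unused`, emit it as one uninterrupted run.
static bool keep_pattern(const std::vector<Node> &nodes, std::vector<bool> &used,
                         std::vector<size_t> &new_order,
                         const std::initializer_list<Op> &pattern, size_t first_unused) {
    if (!match_pattern(nodes, used, pattern, first_unused)) return false;
    for (size_t j = 0; j < pattern.size(); ++j) {
        new_order.push_back(first_unused + j);
        used[first_unused + j] = true;
    }
    return true;
}

The ROPE + VIEW + SET_ROWS case is handled slightly differently in the patch: those three nodes may not be adjacent to begin with, so the pass scans a small window of following nodes for the matching VIEW and SET_ROWS and appends them to the current set so they end up consecutive.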
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index c81b84452..c4e68bc02 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint nrows; uint order; } p; @@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) { dst_row[idx1] = tmp; } -void argsort(bool needs_bounds_check) { +void argsort(bool needs_bounds_check, const uint row) { // bitonic sort const int col = int(gl_LocalInvocationID.x); - const uint row = gl_WorkGroupID.y; const uint row_offset = row * p.ncols; @@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) { void main() { if (p.ncols == BLOCK_SIZE) { - argsort(false); + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } } else { - argsort(true); + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 0d98f5a9d..09676a623 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { #if defined(DATA_A_MXFP4) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { vec2 v0 = dequantize(ib, iqs, a_offset); @@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); const uint scales = data_a[a_offset + ib].scales[scalesi]; - const vec2 d = vec2(data_a[a_offset + ib].d); + const vec2 dm = vec2(data_a[a_offset + ib].dm); - return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); } vec2 get_dm(uint ib, uint a_offset) { return vec2(1, 0); @@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint is = 2 * n + b; // 0..7 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const vec2 loadd = vec2(data_a[a_offset + ib].d); + const vec2 loadd = vec2(data_a[a_offset + ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[a_offset + ib].d); + const vec2 loadd = vec2(data_a[a_offset + ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? 
is : (is - 4); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 67baedf7c..8ac6482dc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); - const f16vec2 d = bl.block.d; + const f16vec2 dm = bl.block.dm; const uint idx = coordInBlock[1]; const uint scalesi = (idx & 0xF0) >> 4; // 0..15 @@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 qs = unpack8(qs)[idx & 1]; const uint scales = bl.block.scales[scalesi]; - float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); + float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4); return ret; } @@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords uint32_t qs = bl.block.qs[iqs]; qs >>= shift; qs &= 0xF; - float16_t ret = float16_t(kvalues_mxfp4[qs] * d); + float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; } #endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp index ffba5a77d..3194ba291 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -26,7 +26,7 @@ void main() { const float d = e8m0_to_fp32(data_a[ib].e); [[unroll]] for (uint l = 0; l < 8; ++l) { - data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); - data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index 58dc2e5df..dc05a7834 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -24,8 +24,8 @@ void main() { const uint ql_idx = 32 * ip + il; const uint8_t qs = data_a[i].qs[32 * ip + il]; - FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y); data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 
8b7be557e..0f23dc0a3 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -20,8 +20,8 @@ void main() { const uint is = 2 * il; const uint n = 4; - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); const uint y_idx = ib * QUANT_K + 64 * il + n * ir; const uint qs_idx = 32*il + n * ir; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 6bc04670f..970469a60 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -19,8 +19,8 @@ void main() { const uint ir = tid % 16; const uint is = 2 * il; - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; const uint qs_idx = 32*il + 2 * ir; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 03ed25d3b..14093c0de 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); @@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); } - temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n])); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 21d07d2e5..49d91ad59 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; @@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, 
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n])); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 9e46c89a1..0d61b4966 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; @@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n])); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index a20788c4b..d260969f0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) -#ifdef MUL_MAT_ID -shared u16vec2 row_ids[BN]; -uint _ne1; - -#ifdef MUL_MAT_ID_USE_SUBGROUPS -shared uvec4 ballots_sh[NUM_WARPS]; - -void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - uint nei0shift = findLSB(p.nei0); - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; - bool in_range = i < num_elements; - uint ii1; - if (nei0_is_pow2) { - ii1 = i >> nei0shift; - } else { - ii1 = i / p.nei0; - } - uint ii0 = i - ii1 * p.nei0; - ids[k] = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_LocalInvocationIndex; - bool in_range = i < num_elements; - uint ii1; - if (nei0_is_pow2) { - ii1 = i >> nei0shift; - } else { - ii1 = i / p.nei0; - } - uint ii0 = i - ii1 * p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - - ballots_sh[gl_SubgroupID] = ballot; - barrier(); - - uint subgroup_base = 0; - uint total = 0; - for (uint k = 0; k < gl_NumSubgroups; ++k) { - if (k == gl_SubgroupID) { - subgroup_base = total; - } - total += subgroupBallotBitCount(ballots_sh[k]); - } - barrier(); - - uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { - row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); - } - _ne1 += total; - iter &= 15; - if (_ne1 >= (ic + 1) * BN) { - break; - } - } - barrier(); -} -#endif // MUL_MAT_ID_USE_SUBGROUPS -#endif // MUL_MAT_ID - #ifdef COOPMAT shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif +#include "mul_mm_id_funcs.glsl" #include "mul_mm_funcs.glsl" void main() { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 0ebfbd646..ee5ded2e8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 - const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 const uint scalesi = iqs / 8; // 0..15 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); const uint scales = data_a[ib].scales[scalesi]; - const vec2 d = vec2(data_a[ib].d); + const vec2 dm = vec2(data_a[ib].dm); - const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_Q3_K) @@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint is = 2 * n + b; // 0..7 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const vec2 loadd = vec2(data_a[ib].d); + const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[ib].d); + const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? 
is : (is - 4); @@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; - const float d = e8m0_to_fp32(data_a[ib].e); + const float d = e8m0_to_fp32(data_a[ib].e) * 0.5; const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl new file mode 100644 index 000000000..1d0e84ac9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl @@ -0,0 +1,70 @@ +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[BN]; +uint _ne1; + +#ifdef MUL_MAT_ID_USE_SUBGROUPS +shared uvec4 ballots_sh[NUM_WARPS]; + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif // MUL_MAT_ID_USE_SUBGROUPS +#endif // MUL_MAT_ID diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index b5d761c0b..d955b4fc7 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -10,10 +10,9 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif -#ifdef COOPMAT -#extension GL_KHR_cooperative_matrix : enable -#extension GL_KHR_memory_scope_semantics : enable +#if defined(MUL_MAT_ID_USE_SUBGROUPS) #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable #endif #ifdef MUL_MAT_ID @@ -24,7 +23,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif @@ -76,40 +78,31 @@ layout 
(constant_id = 10) const uint WARP = 32; #define BK 32 -#ifdef COOPMAT -#define SHMEM_STRIDE (BK / 4 + 4) -#else -#define SHMEM_STRIDE (BK / 4 + 1) -#endif +#define MMQ_SHMEM -shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; - -#ifndef COOPMAT -#if QUANT_AUXF == 1 -shared FLOAT_TYPE buf_a_dm[BM]; -#else -shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; -#endif -#endif - -shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; -#ifndef COOPMAT -shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; -#endif - -#define LOAD_VEC_A (4 * QUANT_R) -#define LOAD_VEC_B 16 +#include "mul_mmq_shmem_types.glsl" #ifdef MUL_MAT_ID -shared u16vec2 row_ids[4096]; -#endif // MUL_MAT_ID +#define BK_STEP 1 +#else +#ifndef BK_STEP +#define BK_STEP 4 +#endif +#endif + +// Shared memory cache +shared block_a_cache buf_a[BM * BK_STEP]; +shared block_b_cache buf_b[BN * BK_STEP]; +// Register cache +block_a_cache cache_a[WMITER * TM]; +block_b_cache cache_b; + +#define LOAD_VEC_A (4 * QUANT_R_MMQ) +#define LOAD_VEC_B 16 #define NUM_WARPS (BLOCK_SIZE / WARP) -#ifdef COOPMAT -shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; -#endif - +#include "mul_mm_id_funcs.glsl" #include "mul_mmq_funcs.glsl" void main() { @@ -139,26 +132,12 @@ void main() { const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WSUBM = WM / WMITER; const uint WSUBN = WN / WNITER; - -#ifdef COOPMAT - const uint warp_i = gl_SubgroupID; - - const uint tiw = gl_SubgroupInvocationID; - - const uint cms_per_row = WM / TM; - const uint cms_per_col = WN / TN; - - const uint storestride = WARP / TM; - const uint store_r = tiw % TM; - const uint store_c = tiw / TM; -#else const uint warp_i = gl_LocalInvocationID.x / WARP; const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiwr = tiw % (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM); -#endif const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); @@ -172,17 +151,27 @@ void main() { const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID - uint _ne1 = 0; - for (uint ii1 = 0; ii1 < p.nei1; ii1++) { - for (uint ii0 = 0; ii0 < p.nei0; ii0++) { +#ifdef MUL_MAT_ID_USE_SUBGROUPS + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); + } +#else + _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { - row_ids[_ne1] = u16vec2(ii0, ii1); + if (_ne1 >= ic * BN) { + row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); + } _ne1++; } } } barrier(); +#endif // Workgroup has no work if (ic * BN >= _ne1) return; @@ -209,159 +198,70 @@ void main() { uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; #endif -#ifdef COOPMAT - coopmat cache_a; - coopmat cache_b; - coopmat cm_result; - - coopmat factors[cms_per_row * cms_per_col]; - - coopmat sums[cms_per_row * cms_per_col]; - - [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { - sums[i] = coopmat(0.0f); - } -#else - int32_t cache_a_qs[WMITER * TM * BK / 4]; - - int32_t cache_b_qs[TN * BK / 4]; - ACC_TYPE sums[WMITER * TM * WNITER * TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); } -#endif -#if QUANT_AUXF == 1 - FLOAT_TYPE cache_a_dm[WMITER * TM]; -#else - FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; -#endif - - FLOAT_TYPE_VEC2 cache_b_ds[TN]; - - for (uint block = start_k; block < end_k; block += BK) { + for (uint block = start_k; block < end_k; block += BK * BK_STEP) { 
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { - const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; - const uint iqs = loadr_a; const uint buf_ib = loadc_a + l; + const uint ib = pos_a_ib + buf_ib * p.stride_a / BK; + const uint iqs = loadr_a; - if (iqs == 0) { -#if QUANT_AUXF == 1 - buf_a_dm[buf_ib] = get_d(ib); -#else - buf_a_dm[buf_ib] = get_dm(ib); -#endif + [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { + block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs); } -#if QUANT_R == 1 - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); -#else - const i32vec2 vals = repack(ib, iqs); - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; -#endif } [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { -#ifdef MUL_MAT_ID - const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; - const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; - const uint ib = idx / 8; - const uint iqs = idx & 0x7; -#else - const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; - const uint ib_outer = ib / 4; - const uint ib_inner = ib % 4; - - const uint iqs = loadr_b; -#endif - const uint buf_ib = loadc_b + l; - if (iqs == 0) { - buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[buf_ib]; + const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK; +#else + const uint ib = pos_b_ib + buf_ib * p.stride_b / BK; +#endif + const uint iqs = loadr_b; + + [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { + block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs); } - const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; } barrier(); - pos_a_ib += 1; - pos_b_ib += 1; + pos_a_ib += BK_STEP; + pos_b_ib += BK_STEP; -#ifdef COOPMAT - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - const uint ib_a = warp_r * WM + cm_row * TM; + for (uint k_step = 0; k_step < BK_STEP; k_step++) { // Load from shared into cache - coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { - cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; - } - - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const uint ib_b = warp_c * WN + cm_col * TN; - coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { - cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; - } - - cm_result = coopmat(0); - cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); - } - - coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); - } - } -#else - // Load from shared 
into cache - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; - cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; - } - } - } - - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; - cache_b_ds[cc] = buf_b_ds[ib]; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; - } - } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint cache_a_idx = wsir * TM + cr; - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - int32_t q_sum = 0; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], - cache_b_qs[cc * (BK / 4) + idx_k]); - } + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint reg_ib = wsir * TM + cr; + const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; - sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); + block_a_to_registers(reg_ib, k_step * BM + buf_ib); + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + block_b_to_registers(ib); + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + + sums[sums_idx] += mmq_dot_product(cache_a_idx); + } } } } } -#endif barrier(); } @@ -373,54 +273,6 @@ void main() { const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif -#ifdef COOPMAT -#ifdef MUL_MAT_ID - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < BN; col += storestride) { - const uint row_i = dc + cm_col * TN + col + store_c; - if (row_i >= _ne1) break; - - const u16vec2 row_idx = row_ids[row_i]; - - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } -#else - const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float - - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; - - if (is_aligned && is_in_bounds) { - // Full coopMat is within bounds and stride_d is aligned with 16B - coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); - coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); - } else if (is_in_bounds) { - // Full coopMat is within bounds, but stride_d 
is not aligned - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { - // Partial coopMat is within bounds - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } - } - } -#endif // MUL_MAT_ID -#else [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { @@ -431,19 +283,21 @@ void main() { const uint row_i = dc_warp + cc; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr; #ifdef MUL_MAT_ID - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); + } #else if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); } #endif // MUL_MAT_ID } } } } -#endif // COOPMAT } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index fe71eb131..c0c03fedc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -6,41 +6,89 @@ // Each iqs value maps to a 32-bit integer -#if defined(DATA_A_Q4_0) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +// 2-byte loads for Q4_0 blocks (18 bytes) +// 4-byte loads for Q4_1 blocks (20 bytes) i32vec2 repack(uint ib, uint iqs) { - // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 - const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1]); +#ifdef DATA_A_Q4_0 + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); return i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); +#else // DATA_A_Q4_1 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +#endif } +#ifdef DATA_A_Q4_0 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); } -#endif - -#if defined(DATA_A_Q4_1) -i32vec2 repack(uint ib, uint iqs) { - // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 - const 
uint32_t vui = data_a_packed32[ib].qs[iqs]; - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -} - +#else // DATA_A_Q4_1 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif -#if defined(DATA_A_Q5_0) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#ifdef DATA_A_Q4_0 + buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + } +#else // DATA_A_Q4_1 + buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + } +#endif +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const uint32_t vui = cache_a[ib_a].qs[iqs]; + const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); + + const int32_t qs_b0 = cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; + + q_sum += dotPacked4x8EXT(qs_a.x, qs_b0); + q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM + +#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +// 2-byte loads for Q5_0 blocks (22 bytes) +// 4-byte loads for Q5_1 blocks (24 bytes) i32vec2 repack(uint ib, uint iqs) { - // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 - const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1]); + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); - const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); +#ifdef DATA_A_Q5_0 + const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); +#else // DATA_A_Q5_1 + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); +#endif const int32_t v0 = int32_t(vui & 0x0F0F0F0F) | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) @@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) { return i32vec2(v0, v1); } +#ifdef DATA_A_Q5_0 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); } -#endif - -#if defined(DATA_A_Q5_1) -i32vec2 repack(uint ib, uint iqs) { - // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); - const int32_t v0 = int32_t(vui & 0x0F0F0F0F) - | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) - - const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) - | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) - - return i32vec2(v0, v1); -} - +#else // DATA_A_Q5_1 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif +#ifdef MMQ_SHMEM 
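// Plain C++ reference for two integer tricks used by the MMQ paths above and below:
// the packed 4x int8 dot product and the Q5_0/Q5_1 high-bit scatter. This is an
// illustrative sketch, assuming dotPacked4x8EXT computes a signed per-byte dot
// product for int operands; the function names here are not from the shaders.
#include <cassert>
#include <cstdint>

// Packed signed 4x8-bit dot product: sum of per-byte int8 products.
static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = (int8_t)(a >> (8 * i));
        const int8_t bi = (int8_t)(b >> (8 * i));
        sum += (int32_t)ai * (int32_t)bi;
    }
    return sum;
}

// ((qh & 0xF) * 0x02040810) & 0x10101010 moves nibble bit i to bit (8*i + 4),
// i.e. (0,1,2,3) -> (4,12,20,28); the partial products never overlap, so no carries.
static uint32_t scatter_high_bits(uint32_t qh_nibble) {
    return ((qh_nibble & 0xFu) * 0x02040810u) & 0x10101010u;
}

int main() {
    // bytes a = {1, 2, 3, -4}, b = {1, -1, 2, 3}: 1*1 - 2*1 + 3*2 - 4*3 = -7
    assert(dot_packed_4x8((int32_t)0xFC030201, (int32_t)0x0302FF01) == -7);
    // nibble 0b1010 -> bits 12 and 28 set
    assert(scatter_high_bits(0xAu) == 0x10001000u);
    return 0;
}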
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#ifdef DATA_A_Q5_0 + buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1])); + } +#else // DATA_A_Q5_1 + buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].qh = data_a_packed32[ib].qh; + } +#endif +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + cache_a[reg_ib].qh = buf_a[buf_ib].qh; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const uint32_t vui = cache_a[ib_a].qs[iqs]; + const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs)); + const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + const int32_t qs_b0 = cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; + + q_sum += dotPacked4x8EXT(qs_a0, qs_b0); + q_sum += dotPacked4x8EXT(qs_a1, qs_b1); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + #if defined(DATA_A_Q8_0) +// 2-byte loads for Q8_0 blocks (34 bytes) int32_t repack(uint ib, uint iqs) { - // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 - return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1])); + return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1])); } ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * da * dsb.x); } + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + const int32_t qs_b = cache_b.qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, qs_b); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_MXFP4) +// 1-byte loads for mxfp4 blocks (17 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + return i32vec2( quants & 0x0F0F0F0F, + (quants >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * dsb.x * float(q_sum)); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const 
uint ib, const uint iqs) { + const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + + buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])); + buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])); + + if (iqs == 0) { + buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d = buf_a[buf_ib].d; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + + return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide +// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants +#if defined(DATA_A_Q2_K) +// 4-byte loads for Q2_K blocks (84 bytes) +int32_t repack(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); +} + +uint8_t get_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + return data_a[ib_k].scales[iqs_k / 4]; +} + +ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + // Repack 4x4 quants into one int + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303; + const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303; + const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303; + + buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + cache_a[reg_ib].scales = buf_a[buf_ib].scales; + + [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t sum_d = 0; + int32_t sum_m = 0; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const uint8_t scale = cache_a[ib_a].scales[iqs / 4]; + const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 
32-bits. + const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303); + + sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); + } + + return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_Q3_K) +// 2-byte loads for Q3_K blocks (110 bytes) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint hm_idx = iqs * QUANT_R_MMQ; + const uint iqs_k = (ib % 8) * 8 + hm_idx; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + const uint hm_shift = iqs_k / 8; + + // Repack 2x4 quants into one int + // Add the 3rd bit instead of subtracting it to allow packing the quants + const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); + buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) | + (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4); + + if (iqs == 0) { + const uint is = iqs_k / 4; + const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | + (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))); + + buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + float result = 0.0; + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + // Subtract 4 from the quants to correct the 3rd bit offset + const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); + q_sum = 0; + + [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { + const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); + + return ACC_TYPE(cache_b.ds.x * result); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return 
ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; + + const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 16) / 8) * 4; + + // Repack 2x4 quants into one int +#if defined(DATA_A_Q4_K) + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; + + buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4); +#else // defined(DATA_A_Q5_K) + const uint qh_idx = iqs * QUANT_R_MMQ; + const uint qh_shift = iqs_k / 8; + + buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); +#endif + + + if (iqs == 0) { + // Scale index + const uint is = iqs_k / 8; + u8vec2 scale_dm; + if (is < 4) { + scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); + } else { + scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), + (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); + } + + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { +#if defined(DATA_A_Q4_K) + const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F); +#else // defined(DATA_A_Q5_K) + const int32_t qs_a = cache_a[ib_a].qs[iqs]; +#endif + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#ifdef MMQ_SHMEM +void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; + + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } + + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b[buf_ib].qs[iqs * 4 ] = values.x; + buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; + buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; + buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; +} + +void block_b_to_registers(const uint ib) { + cache_b.ds = buf_b[ib].ds; + [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; + } +} +#endif + +#if defined(DATA_A_Q6_K) +// 2-byte loads for Q6_K blocks (210 bytes) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; + const uint ql_shift = ((iqs_k % 32) / 16) * 4; + + const uint qh_idx = (iqs_k / 32) * 8 + iqs; + const uint qh_shift = ((iqs_k % 32) / 8) * 2; + + const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> 
ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)); + + if (iqs == 0) { + const uint is = iqs_k / 4; + const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]); + + buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + float result = 0.0; + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); + q_sum = 0; + + [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); + + return ACC_TYPE(cache_b.ds.x * result); +} +#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) @@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) { return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); } #endif + +#if defined(DATA_A_Q2_K) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + const uint ib_k = ib / 8; + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl new file mode 100644 index 000000000..1c0f5306f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -0,0 +1,78 @@ +#if defined(DATA_A_Q4_0) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_Q4_1) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q5_0) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + uint32_t qh; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_Q5_1) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + uint32_t qh; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q8_0) +#define QUANT_R_MMQ 1 +// AMD likes 4, Intel likes 1 and Nvidia likes 2 +// #define BK_STEP 1 +struct block_a_cache { + int32_t qs[32/4]; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_MXFP4) +#define QUANT_R_MMQ 2 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE d; +}; +#elif defined(DATA_A_Q2_K) +#define QUANT_R_MMQ 4 +struct block_a_cache { + uint32_t qs[2]; + u8vec2 scales; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q3_K) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[4]; + FLOAT_TYPE_VEC2 d_scales; +}; +#elif defined(DATA_A_Q4_K) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[4]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q5_K) +#define QUANT_R_MMQ 1 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q6_K) +#define QUANT_R_MMQ 1 +struct block_a_cache { + 
int32_t qs[8]; + FLOAT_TYPE_VEC2 d_scales; +}; +#endif + +struct block_b_cache +{ + int32_t qs[8]; + FLOAT_TYPE_VEC2 ds; +}; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index 50fc1f1e2..0eda186c8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) readonly buffer Y {int data_pos[];}; layout (binding = 2) readonly buffer Z {float data_ff[];}; layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows layout (push_constant) uniform parameter { uint ncols; @@ -27,6 +28,7 @@ layout (push_constant) uniform parameter { uint s2; int sections[4]; uint is_back; + uint set_rows_stride; } p; float rope_yarn_ramp(const float low, const float high, const uint i0) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 06e095bef..9f4538155 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -16,12 +16,19 @@ void main() { const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; - const uint idst = row_dst*ne0 + i0/2; + uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + // Fusion optimization: ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. + if (p.set_rows_stride != 0) { + idst = row_x*ne0 + i0/2; + idst += data_i[channel_x].x * p.set_rows_stride; + } + if (i0 >= p.n_dims) { - data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; - data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]); + data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]); return; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 6ba957540..f4209ed95 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -16,12 +16,19 @@ void main() { const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; - const uint idst = row_dst*ne0 + i0; + uint idst = row_dst*ne0 + i0; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + // Fusion optimization: ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. 
+ if (p.set_rows_stride != 0) { + idst = row_x*ne0 + i0; + idst += data_i[channel_x].x * p.set_rows_stride; + } + if (i0 >= p.n_dims) { - data_d[idst + 0] = data_a[ix + 0]; - data_d[idst + 1] = data_a[ix + 1]; + data_d[idst + 0] = D_TYPE(data_a[ix + 0]); + data_d[idst + 1] = D_TYPE(data_a[ix + 1]); return; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index 9e56d5f8a..bc1c278bf 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -11,6 +11,8 @@ layout (push_constant) uniform parameter { uint n_rows; uint n_expert_used; + float clamp_min; + float clamp_max; }; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; @@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; layout(constant_id = 1) const uint n_experts = 512; layout(constant_id = 2) const bool with_norm = true; +layout(constant_id = 3) const bool late_softmax = false; const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; @@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];}; layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; +const float INFINITY = 1.0 / 0.0; + +// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. +void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) { + float max_val = -INFINITY; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + max_val = max(max_val, vals[i]); + } + } + + max_val = subgroupMax(max_val); + + float sum = 0.f; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + const float val = exp(vals[i] - max_val); + vals[i] = val; + sum += val; + } else { + vals[i] = 0.f; + } + } + + sum = subgroupAdd(sum); + + const float inv_sum = 1.0f / sum; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + vals[i] *= inv_sum; + } + } +} + void main() { const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; if (row >= n_rows) { @@ -35,43 +84,16 @@ void main() { const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; - float logits_r[experts_per_thread]; - - const float INFINITY = 1.0 / 0.0; + float wt[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { - const uint expert = i + gl_LocalInvocationID.x; - logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY; + const uint expert = i + gl_LocalInvocationID.x; + wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 
logits[logits_offset + expert] : -INFINITY; } - float max_val = logits_r[0]; - - [[unroll]] - for (int i = 1; i < experts_per_thread; i++) { - const float val = logits_r[i]; - max_val = max(val, max_val); - } - - max_val = subgroupMax(max_val); - - float wt[experts_per_thread]; - float tmp = 0.f; - - [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - const float val = logits_r[i]; - wt[i] = exp(val - max_val); - tmp += wt[i]; - } - - tmp = subgroupAdd(tmp); - - const float inv_sum = 1.0f / tmp; - - [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - wt[i] = wt[i] * inv_sum; + if (!late_softmax) { + softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false); } // at this point, each thread holds a portion of softmax, @@ -82,6 +104,11 @@ void main() { float output_weights[experts_per_thread]; + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + output_weights[i] = 0.f; + } + for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; uint max_expert = gl_LocalInvocationID.x; @@ -121,6 +148,7 @@ void main() { if (with_norm) { wt_sum = subgroupAdd(wt_sum); + wt_sum = clamp(wt_sum, clamp_min, clamp_max); const float inv_sum = 1.0f / wt_sum; [[unroll]] @@ -129,6 +157,10 @@ void main() { } } + if (late_softmax) { + softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true); + } + [[unroll]] for (uint i = 0; i < experts_per_thread; ++i) { uint idx = i * WARP_SIZE + gl_LocalInvocationID.x; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 2fa54ce51..02578c77c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -66,6 +66,7 @@ struct block_q4_0_packed16 #define QUANT_AUXF 1 #define A_TYPE block_q4_0 #define A_TYPE_PACKED16 block_q4_0_packed16 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q4_1 32 @@ -98,6 +99,7 @@ struct block_q4_1_packed32 #define A_TYPE block_q4_1 #define A_TYPE_PACKED16 block_q4_1_packed16 #define A_TYPE_PACKED32 block_q4_1_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q5_0 32 @@ -123,6 +125,7 @@ struct block_q5_0_packed16 #define QUANT_AUXF 1 #define A_TYPE block_q5_0 #define A_TYPE_PACKED16 block_q5_0_packed16 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q5_1 32 @@ -158,6 +161,7 @@ struct block_q5_1_packed32 #define A_TYPE block_q5_1 #define A_TYPE_PACKED16 block_q5_1_packed16 #define A_TYPE_PACKED32 block_q5_1_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q8_0 32 @@ -186,6 +190,7 @@ struct block_q8_0_packed32 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 #define A_TYPE_PACKED32 block_q8_0_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q8_1 32 @@ -226,21 +231,21 @@ struct block_q2_K { uint8_t scales[QUANT_K_Q2_K/16]; uint8_t qs[QUANT_K_Q2_K/4]; - f16vec2 d; + f16vec2 dm; }; struct block_q2_K_packed16 { uint16_t scales[QUANT_K_Q2_K/16/2]; uint16_t qs[QUANT_K_Q2_K/4/2]; - f16vec2 d; + f16vec2 dm; }; struct block_q2_K_packed32 { uint32_t scales[QUANT_K_Q2_K/16/4]; uint32_t qs[QUANT_K_Q2_K/4/4]; - f16vec2 d; + f16vec2 dm; }; #if defined(DATA_A_Q2_K) @@ -249,6 +254,8 @@ struct block_q2_K_packed32 #define A_TYPE block_q2_K #define A_TYPE_PACKED16 block_q2_K_packed16 #define A_TYPE_PACKED32 block_q2_K_packed32 +#define SCALES_PER_32 2 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q3_K 256 @@ -274,27 +281,28 @@ struct block_q3_K_packed16 #define 
QUANT_R 1 #define A_TYPE block_q3_K #define A_TYPE_PACKED16 block_q3_K_packed16 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q4_K 256 struct block_q4_K { - f16vec2 d; + f16vec2 dm; uint8_t scales[3*QUANT_K_Q4_K/64]; uint8_t qs[QUANT_K_Q4_K/2]; }; struct block_q4_K_packed16 { - f16vec2 d; + f16vec2 dm; uint16_t scales[3*QUANT_K_Q4_K/64/2]; uint16_t qs[QUANT_K_Q4_K/2/2]; }; struct block_q4_K_packed32 { - f16vec2 d; + f16vec2 dm; uint32_t scales[3*QUANT_K_Q4_K/64/4]; uint32_t qs[QUANT_K_Q4_K/2/4]; }; @@ -310,13 +318,14 @@ struct block_q4_K_packed128 #define A_TYPE block_q4_K #define A_TYPE_PACKED16 block_q4_K_packed16 #define A_TYPE_PACKED32 block_q4_K_packed32 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q5_K 256 struct block_q5_K { - f16vec2 d; + f16vec2 dm; uint8_t scales[12]; uint8_t qh[QUANT_K_Q5_K/8]; uint8_t qs[QUANT_K_Q5_K/2]; @@ -324,12 +333,20 @@ struct block_q5_K struct block_q5_K_packed16 { - f16vec2 d; + f16vec2 dm; uint16_t scales[12/2]; uint16_t qh[QUANT_K_Q5_K/8/2]; uint16_t qs[QUANT_K_Q5_K/2/2]; }; +struct block_q5_K_packed32 +{ + f16vec2 dm; + uint32_t scales[12/4]; + uint32_t qh[QUANT_K_Q5_K/8/4]; + uint32_t qs[QUANT_K_Q5_K/2/4]; +}; + struct block_q5_K_packed128 { uvec4 q5k[11]; @@ -340,6 +357,8 @@ struct block_q5_K_packed128 #define QUANT_R 1 #define A_TYPE block_q5_K #define A_TYPE_PACKED16 block_q5_K_packed16 +#define A_TYPE_PACKED32 block_q5_K_packed32 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q6_K 256 @@ -356,7 +375,7 @@ struct block_q6_K_packed16 { uint16_t ql[QUANT_K_Q6_K/2/2]; uint16_t qh[QUANT_K_Q6_K/4/2]; - int8_t scales[QUANT_K_Q6_K/16]; + int16_t scales[QUANT_K_Q6_K/16/2]; float16_t d; }; @@ -365,6 +384,7 @@ struct block_q6_K_packed16 #define QUANT_R 1 #define A_TYPE block_q6_K #define A_TYPE_PACKED16 block_q6_K_packed16 +#define DATA_A_QUANT_K #endif // IQuants @@ -1363,18 +1383,11 @@ struct block_mxfp4 uint8_t qs[QUANT_K_MXFP4/2]; }; -//struct block_mxfp4_packed16 -//{ -// uint8_t e; -// uint16_t qs[QUANT_K_MXFP4/2/2]; -//}; - #if defined(DATA_A_MXFP4) #define QUANT_K QUANT_K_MXFP4 #define QUANT_R QUANT_R_MXFP4 #define QUANT_AUXF 1 #define A_TYPE block_mxfp4 -//#define A_TYPE_PACKED16 block_mxfp4_packed16 #endif #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) @@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize) #endif #if defined(DATA_A_MXFP4) -const FLOAT_TYPE kvalues_mxfp4_const[16] = { - FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), - FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) +const int8_t kvalues_mxfp4_const[16] = { + int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), + int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), }; -shared FLOAT_TYPE kvalues_mxfp4[16]; +shared int8_t kvalues_mxfp4[16]; #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 0f25ba345..e6ec589fb 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } #if 
defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { + // Integer dot mmq performs better with f32 accumulators + if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif @@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::map base_dict = {{"FLOAT_TYPE", "float"}}; + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; // matmul for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { @@ -841,10 +842,14 @@ void process_shaders() { string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});