diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index a7476b10..d0e99b6f 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -25,6 +25,9 @@
 #include
 #include
 
+#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
+#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
+
 #ifdef GGML_WEBGPU_DEBUG
 #    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
 #    define WEBGPU_DEBUG_BUF_ELEMS 32
@@ -64,6 +67,9 @@
 
 /* Constants */
 
+// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
+#define WEBGPU_MAX_WG_SIZE 288
+
 #define WEBGPU_MUL_MAT_WG_SIZE 256
 #define WEBGPU_NUM_PARAM_BUFS 32u
 #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
@@ -272,36 +278,33 @@ struct webgpu_context_struct {
     wgpu::SubgroupMatrixConfig subgroup_matrix_config;
 #endif
 
-    // Separate this out from limits since on some Metal systems, the limit returned by
-    // querying the limits is higher than the actual allowed maximum.
-    uint32_t max_wg_size_x;
-
     std::recursive_mutex mutex;
     std::atomic_uint inflight_threads = 0;
 
     webgpu_buf_pool param_buf_pool;
     webgpu_buf_pool set_rows_error_buf_pool;
 
-    webgpu_pipeline memset_pipeline;
+    std::map<int, webgpu_pipeline> memset_pipelines;  // variant or type index
 
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines;      // src0_type, src1_type, vectorized
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_vec_pipelines;  // src0_type, src1_type, vectorized
 
-    webgpu_pipeline mul_mat_pipeline[30][2];
-    webgpu_pipeline set_rows_pipeline[1][2];     // dst->type, vectorized
-    webgpu_pipeline get_rows_pipeline[30];
-    webgpu_pipeline get_rows_f32_no_vec_pipeline;
-    webgpu_pipeline cpy_pipeline[2][2];          // src type, dst type
-    webgpu_pipeline add_pipeline[2][2];          // type, inplace
-    webgpu_pipeline sub_pipeline[2][2];          // type, inplace
-    webgpu_pipeline mul_pipeline[2][2];          // type, inplace
-    webgpu_pipeline div_pipeline[2][2];          // type, inplace
-    webgpu_pipeline rms_norm_pipeline[2];        // inplace
-    webgpu_pipeline rope_pipeline[2][2][2];      // type, ff, inplace
-    webgpu_pipeline glu_pipeline[7][2][2];       // glu-op, type, split
-    webgpu_pipeline scale_pipeline[2];           // inplace
-    webgpu_pipeline soft_max_pipeline[3][2][2];  // (no_mask, f32_mask, f16_mask), has_sink, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines;                    // dst_type, vectorized
+    std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines;                    // src_type, vectorized
+
+    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                         // src_type, dst_type
+    std::map<int, std::map<int, webgpu_pipeline>> add_pipelines;                         // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines;                         // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines;                         // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> div_pipelines;                         // type, inplace
+
+    std::map<int, webgpu_pipeline>                                rms_norm_pipelines;    // inplace
+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;         // type, ff, inplace
+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;          // glu_op, type, split
+    std::map<int, webgpu_pipeline>                                scale_pipelines;       // inplace
+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines;     // mask_type, has_sink, inplace
+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines;        // unary_op, type, inplace
 
     size_t memset_bytes_per_thread;
@@ -381,35 +384,10 @@ static std::string ggml_webgpu_process_shader_repls(const char *
     return s;
 }
 
-static void ggml_webgpu_create_pipeline(wgpu::Device &                           device,
-                                        webgpu_pipeline &                        pipeline,
-                                        const char *                             shader_code,
-                                        const char *                             label,
-                                        const std::vector<wgpu::ConstantEntry> & constants = {}) {
-    wgpu::ShaderSourceWGSL shader_source;
-    shader_source.code = shader_code;
-
-    wgpu::ShaderModuleDescriptor shader_desc;
-    shader_desc.nextInChain = &shader_source;
-
-    wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
-
-    wgpu::ComputePipelineDescriptor pipeline_desc;
-    pipeline_desc.label = label;
-    pipeline_desc.compute.module
= shader_module; - pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code - pipeline_desc.layout = nullptr; // nullptr means auto layout - if (constants.size() > 0) { - pipeline_desc.compute.constants = constants.data(); - pipeline_desc.compute.constantCount = constants.size(); - } - pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; -} - -static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device & device, - const char * shader_code, - const char * label, - const std::vector & constants = {}) { +static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; @@ -678,10 +656,10 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, std::vector entries = { { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } }; - size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread; - uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x); + webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x); std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) }; ggml_backend_webgpu_wait(ctx, futures); } @@ -763,8 +741,7 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor } static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & - ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); } // Used to determine if two tensors are the same for in-place operations @@ -800,9 +777,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); + return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x); } static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, @@ -851,10 +827,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } }; - size_t max_wg_size = ctx->max_wg_size_x; - int vectorized = src->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized]; + webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized]; uint32_t threads; if (vectorized) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); @@ -862,7 +836,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } - uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; + uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); } @@ -902,13 
+876,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size; + uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE); - webgpu_pipeline pipeline = ctx->get_rows_pipeline[src->type]; - if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) { - pipeline = ctx->get_rows_f32_no_vec_pipeline; - } + uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; + webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -950,10 +921,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; - webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type]; + webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0]; - uint32_t wg_x = - (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; + uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE); uint32_t wg_y = 1; bool use_fast = false; @@ -980,15 +950,13 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; if (dst->ne[1] == 1) { // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = - (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / - ctx->limits.maxComputeWorkgroupsPerDimension; + vectorized = vectorized && (src0->type < 2); + pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension); } else { pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; @@ -998,16 +966,16 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, // The total number of subgroups/workgroups needed per matrix. 
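                 // Illustrative arithmetic (hypothetical numbers, not taken from this patch): if that product were,
                 // say, WEBGPU_MUL_MAT_SUBGROUP_M (2) * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M (2) * config.M (8) == 32,
                 // then a dst with ne[0] == 100 needs wg_m = CEIL_DIV(100, 32) == 4 workgroups along M; wg_n is
                 // computed the same way along N.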
                 uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
-                wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile;
+                wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
                 uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
-                wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile;
+                wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
             } else {
 #endif
                 uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
                 uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
-                wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s;
-                wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s;
+                wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
+                wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
 #ifndef __EMSCRIPTEN__
             }
 #endif
@@ -1018,6 +986,60 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
 }
 
+static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    uint32_t      ne       = (uint32_t) ggml_nelements(dst);
+    ggml_unary_op unary_op = ggml_get_unary_op(dst);
+    uint32_t      inplace  = ggml_webgpu_tensor_equal(src, dst);
+
+    std::vector<uint32_t> params = {
+        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        // Convert byte-strides to element-strides
+        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        // Logical shapes
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
+    };
+
+    switch (unary_op) {
+        case GGML_UNARY_OP_XIELU:
+            {
+                // Get float parameters and reinterpret their bit patterns as uint32_t
+                // for passing through the params buffer
+                float alpha_n = ggml_get_op_params_f32(dst, 1);
+                float alpha_p = ggml_get_op_params_f32(dst, 2);
+                float beta    = ggml_get_op_params_f32(dst, 3);
+                float eps     = ggml_get_op_params_f32(dst, 4);
+                params.push_back(*reinterpret_cast<uint32_t *>(&alpha_n));
+                params.push_back(*reinterpret_cast<uint32_t *>(&alpha_p));
+                params.push_back(*reinterpret_cast<uint32_t *>(&beta));
+                params.push_back(*reinterpret_cast<uint32_t *>(&eps));
+                break;
+            }
+        default:
+            break;
+    }
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_webgpu_tensor_buf(src),
+          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+          .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+    };
+    if (!inplace) {
+        entries.push_back({ .binding = 1,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
+    return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
+}
+
 static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
                                             ggml_tensor *    src0,
                                             ggml_tensor *    src1,
@@ -1059,8 +1081,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    size_t   max_wg_size = ctx->max_wg_size_x;
-
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -1096,7 +1117,7 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src)); } static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, @@ -1181,9 +1202,8 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - webgpu_pipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; + webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -1234,9 +1254,8 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - webgpu_pipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -1273,9 +1292,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -1347,7 +1365,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, + return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, ggml_nrows(dst)); } @@ -1382,22 +1400,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace); } case GGML_OP_SUB: { int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace); } case GGML_OP_MUL: { int inplace = 
ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace); } case GGML_OP_DIV: { int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace); } case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); @@ -1409,6 +1427,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_scale(ctx, src0, node); case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); + case GGML_OP_UNARY: + return ggml_webgpu_unary_op(ctx, src0, node); default: return std::nullopt; } @@ -1641,8 +1661,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, - (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1), + ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, buf_name.c_str()); @@ -1717,58 +1736,64 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_wg_size = webgpu_ctx->max_wg_size_x; - size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = - (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; + webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = max_wg_size; - constants[1].key = "bytes_per_thread"; - constants[1].value = webgpu_ctx->memset_bytes_per_thread; - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = webgpu_ctx->memset_bytes_per_thread; + webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants); } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], - wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], - wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32], - wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32], - wgsl_mul_mat_q5_1_f32, 
"mul_mat_q5_1_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32], - wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32], - wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32], - wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32], - wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32], - wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32], - wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32], - wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32], - wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32], - wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32], - wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32], - wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32], - wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32], - wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32], - wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], - wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + // Q4/Q5/Q8 classic quantizations + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); + + // K-quantizations + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); + + // IQ quantizations (2-, 3-, 4-bit variants) + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); + + // 1-bit and 4-bit IQ variants + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); std::string proc_mul_mat_f32_f32; std::string proc_mul_mat_f32_f32_vec; @@ -1828,21 +1853,21 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { } #endif - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); - 
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); @@ -1853,257 +1878,444 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void 
ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16, - "set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec, - "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); + webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec, - "get_rows_f32_vec", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32, - "get_rows_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16, - "get_rows_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32, - "get_rows_i32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0, - "get_rows_q4_0", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1, - "get_rows_q4_1", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0, - "get_rows_q5_0", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1, - "get_rows_q5_1", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0, - "get_rows_q8_0", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k, - "get_rows_q2_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k, - "get_rows_q3_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k, - "get_rows_q4_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k, - "get_rows_q5_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k, - "get_rows_q6_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS], - wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS], - wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s, - "get_rows_iq2_s", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS], - wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s, - "get_rows_iq3_s", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s, - "get_rows_iq1_s", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m, - "get_rows_iq1_m", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL], - wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS], - wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", 
constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], - wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], - wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], - wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], - wgsl_cpy_f16_f16, "cpy_f16_f16", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace, - "add_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace, - "add_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants); + webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants); + webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); + webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); } static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace, - "sub_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace, - "sub_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants); + webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants); + webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); + webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace, - "mul_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace, - "mul_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants); + webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants); + webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); + webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); } static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace, - "div_f32_inplace", constants); 
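    // Illustrative note on the map-based storage used in this refactor (the lookup expression mirrors the
    // dispatch code earlier in this diff, e.g. ctx->div_pipelines[node->type][inplace]):
    // std::map::operator[] default-constructs a missing entry, so each lookup assumes the pipeline was
    // registered by the matching ggml_webgpu_init_*_pipeline() function at device-initialization time.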
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace, - "div_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants); + webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants); + webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); + webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace, - "rms_norm_inplace", constants); + + webgpu_ctx->rms_norm_pipelines[0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants); + webgpu_ctx->rms_norm_pipelines[1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); } static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, - "rope_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], - wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff, - "rope_f32_ff", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1], - wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16, - "rope_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1], - wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff, - "rope_f16_ff", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1], - wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + + 
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); } static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - // reglu - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], - wgsl_reglu_f32, "reglu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0], - wgsl_reglu_f16, "reglu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1], - wgsl_reglu_f32_split, "reglu_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1], - wgsl_reglu_f16_split, "reglu_f16_split", constants); - // geglu - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0], - wgsl_geglu_f32, "geglu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0], - wgsl_geglu_f16, "geglu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1], - wgsl_geglu_f32_split, "geglu_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1], - wgsl_geglu_f16_split, "geglu_f16_split", constants); - // swiglu - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0], - wgsl_swiglu_f32, "swiglu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0], - wgsl_swiglu_f16, "swiglu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1], - wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1], - wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - // swiglu_oai - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0], - wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1], - wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - // geglu_erf - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0], - wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0], - wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1], - wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1], - wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - // geglu_quick - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0], - wgsl_geglu_quick_f32, "geglu_quick_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0], - wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1], - wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1], - wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + // REGLU + webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); + + // GEGLU + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); + + // SWIGLU + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + + // SWIGLU_OAI + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + + // GEGLU_ERF + 
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + + // GEGLU_QUICK + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); +} + +static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + // ABS + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants); + + // SGN + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants); + + // NEG + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
wgsl_neg_inplace_f16, "neg_inplace_f16", constants); + + // STEP + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants); + + // TANH + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants); + + // ELU + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants); + + // RELU + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants); + + // SIGMOID + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants); + + // GELU + 
webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants); + + // GELU_QUICK + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants); + + // SILU + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants); + + // HARDSWISH + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants); + + // HARDSIGMOID + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, 
wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants); + + // EXP + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants); + + // GELU_ERF + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants); + + // XIELU + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, - "scale_f32_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->scale_pipelines[0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants); + webgpu_ctx->scale_pipelines[1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants); } static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32, - "soft_max_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace, - "soft_max_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink, - "soft_max_f32_sink", constants); - 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1], - wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32, - "soft_max_f32_mask_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1], - wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16, - "soft_max_f32_mask_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1], - wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0], - wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1], - wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0], - wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1], - wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", - constants); + + // f32 (no mask) + webgpu_ctx->soft_max_pipelines[2][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); + webgpu_ctx->soft_max_pipelines[2][0][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); + webgpu_ctx->soft_max_pipelines[2][1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); + webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); + + // f32 mask (mask_type = 0) + webgpu_ctx->soft_max_pipelines[0][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); + webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); + webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); + webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants); + + // f16 mask (mask_type = 1) + webgpu_ctx->soft_max_pipelines[1][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); + webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); + webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); + webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, 
wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -2125,12 +2337,12 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co /* .device = */ dev, /* .context = */ &backend_ctx, }; - return &backend; } static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { // See GGML Backend Buffer Type Interface section + static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, @@ -2295,6 +2507,36 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SOFT_MAX: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_UNARY: + { + const ggml_unary_op UNARY_OP = ggml_get_unary_op(op); + + switch (UNARY_OP) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_XIELU: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + default: + break; + } + } + break; + default: break; } @@ -2388,7 +2630,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ASSERT(ctx->adapter != nullptr); ctx->adapter.GetLimits(&ctx->limits); - ctx->max_wg_size_x = 288; // default value wgpu::AdapterInfo info{}; #ifndef __EMSCRIPTEN__ @@ -2522,6 +2763,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_glu_pipeline(ctx); ggml_webgpu_init_scale_pipeline(ctx); ggml_webgpu_init_soft_max_pipeline(ctx); + ggml_webgpu_init_unary_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index ed8068d4..d61df5bb 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -19,6 +19,15 @@ def parse_decls(decls_text): return decls +def replace_repl_placeholders(variant, template_map): + for repl, code in variant["REPLS"].items(): + for key, val in template_map.items(): + # Match "key" and avoid matching subsequences by using \b + code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) + variant["REPLS"][repl] = code + return variant + + def replace_placeholders(shader_text, replacements): for key, val in replacements.items(): # Match {{KEY}} literally, where KEY is escaped @@ -71,6 +80,10 @@ def generate_variants(fname, input_dir, output_dir, outfile): decls_map = parse_decls(extract_block(text, "DECLS")) except ValueError: decls_map = {} + try: + templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) + except ValueError: + templates_map = {} for fname in sorted(os.listdir(input_dir)): if fname.endswith(".tmpl"): @@ -90,9 +103,11 @@ def generate_variants(fname, input_dir, output_dir, outfile): if key not in decls_map: raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b',
decls_code, shader_template) if "REPLS" in variant: + variant = replace_repl_placeholders(variant, templates_map) + final_shader = replace_placeholders(final_shader, variant["REPLS"]) + # second run to expand placeholders in repl_template final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl new file mode 100644 index 00000000..d474ab10 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -0,0 +1,461 @@ +#define(REPL_TEMPLATES) + +{ + "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);", + "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);", + "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);", + "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];", + "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));", + "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", + "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", + "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", + "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", + "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", + "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);", + "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", + "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" +} + +#end(REPL_TEMPLATES) + +#define(VARIANTS) + +[ + { + "SHADER_NAME": "abs_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "abs_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "abs_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "abs_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "sgn_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": 
"", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sgn_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "neg_f32", + "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "neg_f16", + "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "neg_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "neg_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "step_f32", + "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_f16", + "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "step_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "tanh_f32", + "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_f16", + "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "tanh_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "elu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "elu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "elu_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "elu_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "relu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "relu_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": 
["INPLACE"] + }, + + { + "SHADER_NAME": "sigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "silu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "silu_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "exp_f32", + "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_f16", + "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "exp_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "hardsigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "hardswish_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardswish_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "gelu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_inplace_f32", + "REPLS": 
{ "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "gelu_quick_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "xielu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "xielu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "xielu_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "xielu_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(INPLACE) + +@group(0) @binding(1) +var params: Params; + +#enddecl(INPLACE) + +#decl(NOT_INPLACE) + +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +#enddecl(NOT_INPLACE) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +fn update(dst_i: u32, src_i: u32) { + {{FUNC}} +} + +@group(0) @binding(0) +var src: array<{{TYPE}}>; + +DECLS + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32, + + {{EXT_PARAMS}} +}; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * 
params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + + update(params.offset_dst + dst_idx, params.offset_src + src_idx); +} + +#end(SHADER) +
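
Note on the REPL_TEMPLATES mechanism added to embed_wgsl.py: the template body referenced from a variant's REPLS (e.g. "FUNC": "ABS_FUNC") is expanded first, and the resulting code can itself contain {{...}} placeholders, which is why replace_placeholders runs twice. The standalone Python sketch below is illustrative only and not part of the patch: expand_templates and fill_placeholders are simplified stand-ins for replace_repl_placeholders and replace_placeholders, and demo_templates / demo_variant / demo_shader are made-up inputs modeled on the ABS_FUNC entry in unary_op.wgsl.

import re

# Hypothetical, simplified inputs mirroring unary_op.wgsl (not the real files).
demo_templates = {
    "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
}
demo_variant = {
    "REPLS": {"MUTATE": "dst", "TYPE": "f32", "FUNC": "ABS_FUNC"},
}
demo_shader = "fn update(dst_i: u32, src_i: u32) { {{FUNC}} } // array<{{TYPE}}>"

def expand_templates(variant, templates):
    # Stand-in for replace_repl_placeholders: swap whole-word template keys
    # (e.g. ABS_FUNC) inside the variant's REPLS values for the template body.
    for repl, code in variant["REPLS"].items():
        for key, val in templates.items():
            code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
        variant["REPLS"][repl] = code
    return variant

def fill_placeholders(text, repls):
    # Stand-in for replace_placeholders: resolve {{KEY}} tokens.
    for key, val in repls.items():
        text = text.replace("{{" + key + "}}", str(val))
    return text

variant = expand_templates(demo_variant, demo_templates)
shader = fill_placeholders(demo_shader, variant["REPLS"])  # first pass injects the ABS_FUNC body
shader = fill_placeholders(shader, variant["REPLS"])       # second pass resolves {{MUTATE}} brought in by it
print(shader)
# -> fn update(dst_i: u32, src_i: u32) { dst[dst_i] = abs(src[src_i]); } // array<f32>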