From ebff8f9db90ee9a9754c9d8a51c8bf6650c40157 Mon Sep 17 00:00:00 2001
From: Vishal Singh
Date: Sat, 6 Dec 2025 21:43:33 +0530
Subject: [PATCH] ggml-zendnn : add ZenDNN backend for AMD CPUs (llama/17690)

* ggml-zennn: add ZenDNN backend support

* ggml-zendnn : address ZenDNN backend review fixes and suggestions

* docs : apply blockquote syntax to ZenDNN docs

---------

Co-authored-by: Manoj Kumar
---
 ggml/CMakeLists.txt                  |   4 +
 ggml/include/ggml-zendnn.h           |  22 ++
 ggml/src/CMakeLists.txt              |   1 +
 ggml/src/ggml-backend-reg.cpp        |   8 +
 ggml/src/ggml-zendnn/CMakeLists.txt  |  92 ++++++
 ggml/src/ggml-zendnn/ggml-zendnn.cpp | 466 +++++++++++++++++++++++++++
 6 files changed, 593 insertions(+)
 create mode 100644 ggml/include/ggml-zendnn.h
 create mode 100644 ggml/src/ggml-zendnn/CMakeLists.txt
 create mode 100644 ggml/src/ggml-zendnn/ggml-zendnn.cpp

diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 0ccd9019..6b69ad82 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -253,6 +253,9 @@ option(GGML_HEXAGON "ggml: enable Hexagon backend"
 # toolchain for vulkan-shaders-gen
 set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
 
+option(GGML_ZENDNN "ggml: use ZenDNN" OFF)
+set(ZENDNN_ROOT "" CACHE PATH "ggml: path to ZenDNN installation")
+
 # extra artifacts
 option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
 option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
@@ -314,6 +317,7 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-sycl.h
     include/ggml-vulkan.h
     include/ggml-webgpu.h
+    include/ggml-zendnn.h
     include/gguf.h)
 
 set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
diff --git a/ggml/include/ggml-zendnn.h b/ggml/include/ggml-zendnn.h
new file mode 100644
index 00000000..a30a3a98
--- /dev/null
+++ b/ggml/include/ggml-zendnn.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include "ggml-backend.h"
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void);
+
+GGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend);
+
+// number of threads used for zendnn operations
+GGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 98606e9c..4c04c330 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -440,6 +440,7 @@ ggml_add_backend(WebGPU)
 ggml_add_backend(zDNN)
 ggml_add_backend(OpenCL)
 ggml_add_backend(Hexagon)
+ggml_add_backend(ZenDNN)
 
 foreach (target ggml-base ggml)
     target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 88e6dc45..4181a714 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -73,6 +73,10 @@
 #include "ggml-cann.h"
 #endif
 
+#ifdef GGML_USE_ZENDNN
+#include "ggml-zendnn.h"
+#endif
+
 // disable C++17 deprecation warning for std::codecvt_utf8
 #if defined(__clang__)
 # pragma clang diagnostic push
@@ -203,6 +207,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_OPENCL
         register_backend(ggml_backend_opencl_reg());
 #endif
+#ifdef GGML_USE_ZENDNN
+        register_backend(ggml_backend_zendnn_reg());
+#endif
 #ifdef GGML_USE_HEXAGON
         register_backend(ggml_backend_hexagon_reg());
 #endif
@@ -605,6 +612,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
 #endif
 
     ggml_backend_load_best("blas", silent, dir_path);
+    ggml_backend_load_best("zendnn", silent, dir_path);
     ggml_backend_load_best("cann", silent, dir_path);
     ggml_backend_load_best("cuda", silent, dir_path);
     ggml_backend_load_best("hip", silent, dir_path);
diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt
new file mode 100644
index 00000000..bdbfc743
--- /dev/null
+++ b/ggml/src/ggml-zendnn/CMakeLists.txt
@@ -0,0 +1,92 @@
+ggml_add_backend_library(ggml-zendnn
+                         ggml-zendnn.cpp)
+
+# Get ZenDNN path
+if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "")
+    set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}")
+endif()
+
+# Check if path is still empty or OFF
+if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
+    message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...")
+    message(STATUS "This will take several minutes on first build...")
+
+    include(ExternalProject)
+
+    set(ZENDNN_PREFIX      ${CMAKE_BINARY_DIR}/_deps/zendnn-prefix)
+    set(ZENDNN_SOURCE_DIR  ${ZENDNN_PREFIX}/src/zendnn)
+    set(ZENDNN_BUILD_DIR   ${ZENDNN_PREFIX}/build)
+    set(ZENDNN_INSTALL_DIR ${ZENDNN_BUILD_DIR}/install)
+
+    ExternalProject_Add(
+        zendnn
+        GIT_REPOSITORY https://github.com/amd/ZenDNN.git
+        GIT_TAG zendnnl
+        PREFIX ${ZENDNN_PREFIX}
+        SOURCE_DIR ${ZENDNN_SOURCE_DIR}
+        BINARY_DIR ${ZENDNN_BUILD_DIR}
+        CMAKE_ARGS
+            -DCMAKE_BUILD_TYPE=Release
+            -DCMAKE_INSTALL_PREFIX=${ZENDNN_INSTALL_DIR}
+            -DZENDNNL_BUILD_EXAMPLES=OFF
+            -DZENDNNL_BUILD_DOXYGEN=OFF
+            -DZENDNNL_BUILD_GTEST=OFF
+            -DZENDNNL_BUILD_BENCHDNN=OFF
+            # Enable ALL matmul algorithm backends
+            -DZENDNNL_DEPENDS_AOCLDLP=ON
+            -DZENDNNL_DEPENDS_ONEDNN=ON
+            -DZENDNNL_DEPENDS_LIBXSMM=ON
+        BUILD_COMMAND ${CMAKE_COMMAND} --build ${ZENDNN_BUILD_DIR} --target zendnnl
+        INSTALL_COMMAND ${CMAKE_COMMAND} --build ${ZENDNN_BUILD_DIR} --target install
+        BUILD_ALWAYS OFF
+        LOG_DOWNLOAD ON
+        LOG_CONFIGURE ON
+        LOG_BUILD ON
+        LOG_INSTALL ON
+    )
+
+    # Add dependency so ZenDNN builds before our library
+    add_dependencies(ggml-zendnn zendnn)
+
+    # Set ZENDNN_ROOT to the installation directory
+    set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR})
+
+    message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}")
+else()
+    message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}")
+endif()
+
+# ZenDNN headers + libs
+target_include_directories(ggml-zendnn PRIVATE
+    ${ZENDNN_ROOT}/zendnnl/include
+    ${ZENDNN_ROOT}/deps/aocldlp/include
+    ${ZENDNN_ROOT}/deps/aoclutils/include
+    ${ZENDNN_ROOT}/deps/json/include
+    ${ZENDNN_ROOT}/deps/libxsmm/include
+    ${ZENDNN_ROOT}/deps/onednn/include
+)
+
+target_link_directories(ggml-zendnn PRIVATE
+    ${ZENDNN_ROOT}/zendnnl/lib
+    ${ZENDNN_ROOT}/deps/aocldlp/lib
+    ${ZENDNN_ROOT}/deps/aoclutils/lib
+    ${ZENDNN_ROOT}/deps/libxsmm/lib
+    ${ZENDNN_ROOT}/deps/onednn/lib
+)
+
+target_link_libraries(ggml-zendnn PRIVATE
+    zendnnl_archive   # ZenDNN main
+    aocl-dlp          # AOCL libraries
+    aoclutils
+    au_cpuid
+    dnnl              # OneDNN
+    xsmm              # libxsmm small matrix math
+    xsmmext
+    xsmmnoblas
+    m
+    pthread
+)
+
+if (GGML_OPENMP)
+    target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX)
+endif()
diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
new file mode 100644
index 00000000..fd07f983
--- /dev/null
+++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
@@ -0,0 +1,466 @@
+#include "ggml-zendnn.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-impl.h"
+#include "ggml-cpu.h"
+#include "zendnnl.hpp"
+
+#include <cstring>
+
+
+struct ggml_backend_zendnn_context {
+    int n_threads = GGML_DEFAULT_N_THREADS;
+    std::unique_ptr<char[]> work_data;
+    size_t work_size = 0;
+};
+
+template <typename T>
+zendnnl::common::data_type_t ggml_to_zendnn_type() {
+    if constexpr (std::is_same_v<T, float>) {
+        return zendnnl::common::data_type_t::f32;
+    } else if constexpr (std::is_same_v<T, ggml_bf16_t>) {
+        return zendnnl::common::data_type_t::bf16;
+    } else {
+        return zendnnl::common::data_type_t::none;
+    }
+}
+
+/**
+ * ZenDNN matmul: computes C = B * A.
+ *
+ * - A: weights, shape (k, m), column-major (each column is a weight vector for one output).
+ * - B: input, shape (n, k), row-major (each row is an input sample).
+ * - C: output, shape (n, m), row-major.
+ *
+ * Dimensions:
+ *   m = output features (columns of C, columns of A)
+ *   n = batch size (rows of C, rows of B)
+ *   k = inner dimension (columns of B, rows of A)
+ */
+template <typename TA, typename TB, typename TC>
+static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
+                               const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,
+                               int64_t ldc) {
+
+    zendnnl::lowoha::lowoha_params params;
+    params.dtypes.src = ggml_to_zendnn_type<TB>();
+    params.dtypes.wei = ggml_to_zendnn_type<TA>();
+    params.dtypes.dst = ggml_to_zendnn_type<TC>();
+    params.num_threads = ctx->n_threads;
+
+    zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct(
+        'r', false, true,   // row-major, don't transpose B, transpose A (because it's column-major)
+        n,                  // M: rows of B and C
+        m,                  // N: cols of A^T and C
+        k,                  // K: cols of B, rows of A
+        1.0f,               // alpha
+        B, ldb,             // src: B[n,k]
+        A, lda,             // weight: A[k,m] column-major (transposed)
+        nullptr,            // bias
+        0.0f,               // beta
+        C, ldc,             // output C[n,m]
+        true,               // is_weights_const
+        {},                 // batch_params
+        params              // params
+    );
+
+    if (status != zendnnl::lowoha::status_t::success) {
+        GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast<int>(status));
+        return false;
+    }
+    return true;
+}
+
+static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
+                              const void * A, int64_t lda, const void * B, int64_t ldb, void * C,
+                              int64_t ldc, int Atype, int Btype, int Ctype) {
+
+    assert(m >= 0);
+    assert(n >= 0);
+    assert(k >= 0);
+    assert(lda >= k);
+    assert(ldb >= k);
+    assert(ldc >= m);
+
+    // categorize types
+    switch (Atype) {
+        case GGML_TYPE_F32:
+            if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)
+                return false;
+            return ggml_zendnn_matmul<float, float, float>(
+                ctx, m, n, k,
+                (const float *)A, lda,
+                (const float *)B, ldb,
+                (float *)C, ldc);
+        case GGML_TYPE_BF16:
+            if (Btype != GGML_TYPE_BF16)
+                return false;
+            if (Ctype == GGML_TYPE_BF16)
+                return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(
+                    ctx, m, n, k,
+                    (const ggml_bf16_t *)A, lda,
+                    (const ggml_bf16_t *)B, ldb,
+                    (ggml_bf16_t *)C, ldc);
+            if (Ctype == GGML_TYPE_F32)
+                return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, float>(
+                    ctx, m, n, k,
+                    (const ggml_bf16_t *)A, lda,
+                    (const ggml_bf16_t *)B, ldb,
+                    (float *)C, ldc);
+            return false;
+        default:
+            return false; // unsupported type
+    }
+}
+
+static void ggml_zendnn_compute_forward_mul_mat(
+    ggml_backend_zendnn_context * ctx,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];  // weights
+    const ggml_tensor * src1 = dst->src[1];  // inputs
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    ggml_type         const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    void * work_data = ctx->work_data.get();
+    if (src1->type != vec_dot_type) {
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1 * ne11;
+        const size_t nbw3 = nbw2 * ne12;
+        const size_t desired_wsize = ne13 * nbw3;
+        if (ctx->work_size < desired_wsize) {
+            ctx->work_data.reset(new char[desired_wsize]);
+            ctx->work_size = desired_wsize;
+        }
+        work_data = ctx->work_data.get();
+
+        // #pragma omp parallel for num_threads(ctx->n_threads)
+        #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                    void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
+                    from_float(src1_f32, src1_conv, ne10);
+                }
+            }
+        }
+    }
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;
+            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+            if (!ggml_zendnn_sgemm(ctx,
+                                  ne01,     // m
+                                  ne11,     // n
+                                  ne10,     // k
+                                  static_cast<const char *>(src0->data) + (i12/r2)*nb02 + (i13/r3)*nb03,
+                                  ne00,     // lda
+                                  static_cast<const char *>(wdata) + (i12*ne11 + i13*ne12*ne11)*row_size,
+                                  ne10,     // ldb
+                                  static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,
+                                  ne01,     // ldc
+                                  src0->type,
+                                  vec_dot_type,
+                                  dst->type))
+                GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
+        }
+    }
+}
+
+// backend interface
+
+static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
+    return "ZenDNN";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_zendnn_free(ggml_backend_t backend) {
+    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
+    delete ctx;
+    delete backend;
+}
+
+static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+                ggml_zendnn_compute_forward_mul_mat(ctx, node);
+                break;
+            case GGML_OP_NONE:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+                break;
+
+            default:
+                GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i ggml_backend_zendnn_i = {
+    /* .get_name                = */ ggml_backend_zendnn_get_name,
+    /* .free                    = */ ggml_backend_zendnn_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_zendnn_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .graph_optimize          = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_zendnn_guid(void) {
+    static const char * guid_str = "AMD-ZENDNN-ACCEL";
+    return reinterpret_cast<ggml_guid_t>(const_cast<char *>(guid_str));
+}
+
+ggml_backend_t ggml_backend_zendnn_init(void) {
+    ggml_backend_zendnn_context * ctx = new ggml_backend_zendnn_context;
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid    = */ ggml_backend_zendnn_guid(),
+        /* .iface   = */ ggml_backend_zendnn_i,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_zendnn_reg(), 0),
+        /* .context = */ ctx,
+    };
+
+    return backend;
+}
+
+bool ggml_backend_is_zendnn(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_zendnn_guid());
+}
+
+void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_zendnn(backend_zendnn));
+
+    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend_zendnn->context;
+    ctx->n_threads = n_threads;
+}
+
+// device interface
+static const char * ggml_backend_zendnn_device_get_name(ggml_backend_dev_t dev) {
+    return "ZenDNN";
+
+    GGML_UNUSED(dev);
+}
+/**
+ * ZenDNN is AMD's performance library providing optimized primitives and implementations
+ * for deep learning workloads on AMD CPUs. It targets improved performance for common
+ * neural network operations on AMD architectures. For more information, see:
+ * https://www.amd.com/en/developer/zendnn.html
+ */
+static const char * ggml_backend_zendnn_device_get_description(ggml_backend_dev_t dev) {
+    return "ZenDNN: AMD optimized primitives backend for GGML (optimized for AMD CPUs)";
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_zendnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_zendnn_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_zendnn_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_zendnn_device_get_name(dev);
+    props->description = ggml_backend_zendnn_device_get_description(dev);
+    props->type        = ggml_backend_zendnn_device_get_type(dev);
+    ggml_backend_zendnn_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                = */ false,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ true,
+        /* .events               = */ false
+    };
+}
+
+static ggml_backend_t ggml_backend_zendnn_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_t backend = ggml_backend_zendnn_init();
+    if (backend == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to initialize ZenDNN backend\n", __func__);
+        return NULL;
+    }
+
+    return backend;
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_zendnn_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            const ggml_tensor * weights = op->src[0];
+            const ggml_tensor * inputs = op->src[1];
+
+            const int64_t ne10 = inputs->ne[0];
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            const int64_t min_batch = 1;
+            if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
+                ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
+                return false;
+            }
+            switch (weights->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_BF16: // the compute path reads inputs as F32 unless they already match the weights' type
+                    return inputs->type == GGML_TYPE_F32 || inputs->type == weights->type;
+                default:
+                    return false;
+            }
+        } break;
+
+        default:
+            return false;
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_zendnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_zendnn_device_i = {
+    /* .get_name               = */ ggml_backend_zendnn_device_get_name,
+    /* .get_description        = */ ggml_backend_zendnn_device_get_description,
+    /* .get_memory             = */ ggml_backend_zendnn_device_get_memory,
+    /* .get_type               = */ ggml_backend_zendnn_device_get_type,
+    /* .get_props              = */ ggml_backend_zendnn_device_get_props,
+    /* .init_backend           = */ ggml_backend_zendnn_device_init_backend,
+    /* .get_buffer_type        = */ ggml_backend_zendnn_device_get_buffer_type,
+    /* .get_host_buffer_type   = */ NULL,
+    /* .buffer_from_host_ptr   = */ ggml_backend_zendnn_device_buffer_from_host_ptr,
+    /* .supports_op            = */ ggml_backend_zendnn_device_supports_op,
+    /* .supports_buft          = */ ggml_backend_zendnn_device_supports_buft,
+    /* .offload_op             = */ NULL,
+    /* .event_new              = */ NULL,
+    /* .event_free             = */ NULL,
+    /* .event_synchronize      = */ NULL,
+};
+
+// backend reg interface
+static const char * ggml_backend_zendnn_reg_get_name(ggml_backend_reg_t reg) {
+    return "ZenDNN";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_zendnn_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_zendnn_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_zendnn_device = {
+        /* .iface   = */ ggml_backend_zendnn_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_zendnn_device;
+}
+
+static void * ggml_backend_zendnn_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *) ggml_backend_zendnn_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_zendnn_reg_i = {
+    /* .get_name         = */ ggml_backend_zendnn_reg_get_name,
+    /* .get_device_count = */ ggml_backend_zendnn_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_zendnn_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_zendnn_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_zendnn_reg(void) {
+    static struct ggml_backend_reg ggml_backend_zendnn_reg = {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_zendnn_reg_i,
+        /* .context     = */ NULL,
+    };
+
+    return &ggml_backend_zendnn_reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_zendnn_reg)