sycl: add missing BF16 conversion support for Intel oneAPI (llama/17780)

* sycl: add missing BF16 conversion support for Intel oneAPI

* Fix Line 645: Trailing whitespace
This commit is contained in:
Law Po Ying 2025-12-07 09:18:18 +08:00 committed by Georgi Gerganov
parent 898f876fe2
commit 447ef8633b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 19 additions and 0 deletions

View File

@ -2,6 +2,13 @@
#include "dequantize.hpp" #include "dequantize.hpp"
#include "presets.hpp" #include "presets.hpp"
#if defined(__INTEL_LLVM_COMPILER)
#if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
#include <sycl/ext/oneapi/bfloat16.hpp>
#define GGML_SYCL_HAS_BF16
#endif
#endif
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
@ -566,6 +573,10 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_iq4_nl_sycl; return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F32: case GGML_TYPE_F32:
return convert_unary_sycl<float>; return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default: default:
return nullptr; return nullptr;
} }
@ -627,6 +638,10 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_iq4_nl_sycl; return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F16: case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>; return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default: default:
return nullptr; return nullptr;
} }
@ -636,6 +651,10 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
switch (type) { switch (type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
return convert_unary_nc_sycl<float>; return convert_unary_nc_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default: default:
return nullptr; return nullptr;
} }