ggml update to b6840 (#12791)

Daniel Hiltgen 2025-11-06 10:19:22 -08:00 committed by GitHub
parent c4ba257c64
commit 544b6739dd
103 changed files with 3644 additions and 1215 deletions


@@ -1,6 +1,6 @@
 UPSTREAM=https://github.com/ggml-org/llama.cpp.git
 WORKDIR=llama/vendor
-FETCH_HEAD=7049736b2dd9011bf819e298b844ebbc4b5afdc9
+FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
 .PHONY: help
 help:

llama/build-info.cpp

@@ -1,4 +1,4 @@
 int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "7049736b2dd9011bf819e298b844ebbc4b5afdc9";
+char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
 char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";


@@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }
-static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
-    auto has_min = min_value != std::numeric_limits<int>::min();
-    auto has_max = max_value != std::numeric_limits<int>::max();
+static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
+    auto has_min = min_value != std::numeric_limits<int64_t>::min();
+    auto has_max = max_value != std::numeric_limits<int64_t>::max();
     auto digit_range = [&](char from, char to) {
         out << "[";
@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
     if (has_min) {
         if (min_value < 0) {
             out << "\"-\" (";
-            _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
+            _build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
             out << ") | [0] | [1-9] ";
             more_digits(0, decimals_left - 1);
         } else if (min_value == 0) {
@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
                 }
                 digit_range(c, c);
                 out << " (";
-                _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
+                _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
                 out << ")";
                 if (c < '9') {
                     out << " | ";
@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
             _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
         } else {
             out << "\"-\" (";
-            _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
+            _build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
             out << ")";
         }
         return;
@ -925,17 +925,17 @@ public:
         int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
         return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
     } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
-        int min_value = std::numeric_limits<int>::min();
-        int max_value = std::numeric_limits<int>::max();
+        int64_t min_value = std::numeric_limits<int64_t>::min();
+        int64_t max_value = std::numeric_limits<int64_t>::max();
         if (schema.contains("minimum")) {
-            min_value = schema["minimum"].get<int>();
+            min_value = schema["minimum"].get<int64_t>();
         } else if (schema.contains("exclusiveMinimum")) {
-            min_value = schema["exclusiveMinimum"].get<int>() + 1;
+            min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
         }
         if (schema.contains("maximum")) {
-            max_value = schema["maximum"].get<int>();
+            max_value = schema["maximum"].get<int64_t>();
         } else if (schema.contains("exclusiveMaximum")) {
-            max_value = schema["exclusiveMaximum"].get<int>() - 1;
+            max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
         }
         std::stringstream out;
         out << "(";


@@ -5,6 +5,7 @@
 #include <map>
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+    { LLM_ARCH_CLIP,             "clip"             }, // dummy, only used by llama-quantize
     { LLM_ARCH_LLAMA,            "llama"            },
     { LLM_ARCH_LLAMA4,           "llama4"           },
     { LLM_ARCH_DECI,             "deci"             },
@ -85,6 +86,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM,              "plm"              },
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
+    { LLM_ARCH_BAILINGMOE2,      "bailingmoe2"      },
     { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_ARCEE,            "arcee"            },
     { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
@ -135,6 +137,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_COUNT,            "%s.expert_count"            },
     { LLM_KV_EXPERT_USED_COUNT,       "%s.expert_used_count"       },
     { LLM_KV_EXPERT_SHARED_COUNT,     "%s.expert_shared_count"     },
+    { LLM_KV_EXPERT_GROUP_COUNT,      "%s.expert_group_count"      },
+    { LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
     { LLM_KV_EXPERT_WEIGHTS_SCALE,    "%s.expert_weights_scale"    },
     { LLM_KV_EXPERT_WEIGHTS_NORM,     "%s.expert_weights_norm"     },
     { LLM_KV_EXPERT_GATING_FUNC,      "%s.expert_gating_func"      },
@ -277,6 +281,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 };
 static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
+    {
+        LLM_ARCH_CLIP,
+        {},
+    },
     {
         LLM_ARCH_LLAMA,
         {
@ -1961,6 +1969,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP,   "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_BAILINGMOE2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,              "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,             "output_norm" },
+            { LLM_TENSOR_OUTPUT,                  "output" },
+            { LLM_TENSOR_ATTN_NORM,               "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q_NORM,             "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,             "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_QKV,                "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,                "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP,            "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,         "blk.%d.exp_probs_b" },
+            { LLM_TENSOR_FFN_NORM,                "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,                "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,                "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,                  "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS,           "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,           "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,             "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,          "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,          "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,            "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_NEXTN_EH_PROJ,           "blk.%d.nextn.eh_proj" },
+            { LLM_TENSOR_NEXTN_EMBED_TOKENS,      "blk.%d.nextn.embed_tokens" },
+            { LLM_TENSOR_NEXTN_ENORM,             "blk.%d.nextn.enorm" },
+            { LLM_TENSOR_NEXTN_HNORM,             "blk.%d.nextn.hnorm" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,  "blk.%d.nextn.shared_head_head" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,  "blk.%d.nextn.shared_head_norm" },
+            { LLM_TENSOR_LAYER_OUT_NORM,          "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_DOTS1,
         {


@@ -9,6 +9,7 @@
 //
 enum llm_arch {
+    LLM_ARCH_CLIP,
     LLM_ARCH_LLAMA,
     LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
@ -89,6 +90,7 @@ enum llm_arch {
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
+    LLM_ARCH_BAILINGMOE2,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
@ -139,6 +141,8 @@ enum llm_kv {
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_EXPERT_SHARED_COUNT,
+    LLM_KV_EXPERT_GROUP_COUNT,
+    LLM_KV_EXPERT_GROUP_USED_COUNT,
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,


@@ -123,7 +123,7 @@ private:
     uint32_t n_seq_max;
     uint32_t n_outputs;
-    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+    std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
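The only change here is the extra pair of braces on seq_id_0. std::array is an aggregate that wraps a built-in array, so the fully braced form spells out both levels; some compilers warn about the single-brace form under -Wmissing-braces. A small illustration, not from the diff:

#include <array>

// Outer braces initialize the std::array object, inner braces its array member.
std::array<int, 1> single = { 0 };   // valid, but may trigger -Wmissing-braces on some compilers
std::array<int, 1> nested = {{ 0 }}; // explicitly braced, warning-free

int main() { return single[0] + nested[0]; }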


@@ -63,6 +63,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "megrez",             LLM_CHAT_TEMPLATE_MEGREZ },
     { "yandex",             LLM_CHAT_TEMPLATE_YANDEX },
     { "bailing",            LLM_CHAT_TEMPLATE_BAILING },
+    { "bailing-think",      LLM_CHAT_TEMPLATE_BAILING_THINK },
+    { "bailing2",           LLM_CHAT_TEMPLATE_BAILING2 },
     { "llama4",             LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm",            LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe",        LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
@ -191,6 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_YANDEX;
     } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
         return LLM_CHAT_TEMPLATE_BAILING;
+    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("\"HUMAN\"") && tmpl_contains("<think>")) {
+        return LLM_CHAT_TEMPLATE_BAILING_THINK;
+    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("<role>HUMAN</role>") && tmpl_contains("<|role_end|>")) {
+        return LLM_CHAT_TEMPLATE_BAILING2;
     } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
     } else if (tmpl_contains("<|endofuserprompt|>")) {
@ -644,8 +650,8 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << " Ассистент:[SEP]";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
-        // Bailing (Ling) template
+    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
+        // Bailing (Ling/Ring) template
         for (auto message : chat) {
             std::string role(message->role);
@ -658,6 +664,33 @@ int32_t llm_chat_apply_template(
             ss << "<role>" << role << "</role>" << message->content;
         }
+        if (add_ass) {
+            ss << "<role>ASSISTANT</role>";
+            if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
+                ss << "<think>";
+            }
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) {
+        // Bailing2 (Ling 2.0) template
+        bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
+        if (!has_system) {
+            ss << "<role>SYSTEM</role>detailed thinking off<|role_end|>";
+        }
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                role = "HUMAN";
+            } else {
+                std::transform(role.begin(), role.end(), role.begin(), ::toupper);
+            }
+            ss << "<role>" << role << "</role>" << message->content << "<|role_end|>";
+        }
         if (add_ass) {
             ss << "<role>ASSISTANT</role>";
         }
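To see what the new Bailing2 branch renders, the sketch below re-creates just that string-building logic outside llama.cpp for a made-up three-message chat; the role names and tag strings come from the added code, everything else is illustrative.

#include <algorithm>
#include <cctype>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

int main() {
    std::vector<std::pair<std::string, std::string>> chat = {
        {"user", "What is 2+2?"},
        {"assistant", "4"},
        {"user", "Thanks!"},
    };
    const bool add_ass = true;

    std::ostringstream ss;
    const bool has_system = !chat.empty() && chat[0].first == "system";
    if (!has_system) {
        ss << "<role>SYSTEM</role>detailed thinking off<|role_end|>";
    }
    for (const auto & [role_in, content] : chat) {
        std::string role = role_in;
        if (role == "user") {
            role = "HUMAN";
        } else {
            std::transform(role.begin(), role.end(), role.begin(), ::toupper);
        }
        ss << "<role>" << role << "</role>" << content << "<|role_end|>";
    }
    if (add_ass) {
        ss << "<role>ASSISTANT</role>";
    }
    std::cout << ss.str() << "\n";
    // Prints (single line, wrapped here for readability):
    // <role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role>What is 2+2?<|role_end|>
    // <role>ASSISTANT</role>4<|role_end|><role>HUMAN</role>Thanks!<|role_end|><role>ASSISTANT</role>
}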


@@ -42,6 +42,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MEGREZ,
     LLM_CHAT_TEMPLATE_YANDEX,
     LLM_CHAT_TEMPLATE_BAILING,
+    LLM_CHAT_TEMPLATE_BAILING_THINK,
+    LLM_CHAT_TEMPLATE_BAILING2,
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,


@@ -2345,7 +2345,8 @@ llama_context * llama_init_from_model(
         return nullptr;
     }
-    if (params.pooling_type != model->hparams.pooling_type) {
+    if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
+        params.pooling_type != model->hparams.pooling_type) {
         //user-specified pooling-type is different from the model default
         LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
                 model->hparams.pooling_type, params.pooling_type);


@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
-                                (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
-                                (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@ -295,51 +300,68 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-    float * data = (float *) kq_mask->data;
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
-                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
-                    if (s0 != s1) {
-                        continue; // skip different sequences
-                    }
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
-                    }
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
-                    }
-                }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
-            }
-        }
-    }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
-    }
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos    p1 = ubatch->pos[i1];
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos    p0 = ubatch->pos[i0];
+                    // mask different sequences
+                    if (s0 != s1) {
+                        continue;
+                    }
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
+                    }
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
+                    }
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                }
+            }
+        }
+    };
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+        float * data = (float *) self_kq_mask->data;
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
+    }
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+        float * data = (float *) self_kq_mask_swa->data;
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
+    }
 }
 void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
@@ -928,6 +950,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }
+    // select top n_group_used expert groups
+    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
+    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
+        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
+        // organize experts into n_expert_groups
+        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
+        ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
+        // get top n_group_used expert groups
+        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
+        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
+        ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
+        cb(expert_groups, "ffn_moe_group_topk", il);
+        // mask out the other groups
+        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
+        selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
+        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
+        cb(selection_probs, "ffn_moe_probs_masked", il);
+    }
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
@ -959,6 +1006,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
     cb(weights_sum, "ffn_moe_weights_sum", il);
+    if (arch == LLM_ARCH_BAILINGMOE2) {
+        weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
+        cb(weights_sum, "ffn_moe_weights_sum_biased", il);
+    }
     weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights_norm", il);
@@ -1299,12 +1351,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         k = ggml_permute(ctx0, k, 0, 2, 1, 3);
         v = ggml_permute(ctx0, v, 0, 2, 1, 3);
-        const auto n_kv = k->ne[1];
         ggml_tensor * cur;
-        // TODO: replace hardcoded padding with ggml-provided padding
-        if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+        if (cparams.flash_attn && kq_b == nullptr) {
             GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
             if (v_trans) {
@ -1419,10 +1468,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
@ -1447,7 +1506,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
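The block added to build_moe_ffn above implements DeepSeek-V3-style group-limited expert routing as ggml graph ops. As a reading aid, here is the same selection written as plain scalar C++ for a single token; the sizes and probabilities are invented and only the hparams names mirror the diff.

#include <algorithm>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

int main() {
    const int n_expert        = 8;
    const int n_expert_groups = 4; // hparams.n_expert_groups
    const int n_group_used    = 2; // hparams.n_group_used
    const int n_expert_used   = 2;
    const int n_per_group     = n_expert / n_expert_groups;

    std::vector<float> probs = {0.9f, 0.1f, 0.2f, 0.3f, 0.8f, 0.7f, 0.05f, 0.6f};

    // score each group by the sum of its top-2 expert probabilities
    std::vector<std::pair<float, int>> group_scores;
    for (int g = 0; g < n_expert_groups; ++g) {
        std::vector<float> grp(probs.begin() + g*n_per_group, probs.begin() + (g + 1)*n_per_group);
        const int k = std::min(2, n_per_group);
        std::partial_sort(grp.begin(), grp.begin() + k, grp.end(), std::greater<float>());
        float score = 0.0f;
        for (int j = 0; j < k; ++j) {
            score += grp[j];
        }
        group_scores.emplace_back(score, g);
    }

    // keep the top n_group_used groups and mask the experts of all other groups
    std::sort(group_scores.begin(), group_scores.end(), std::greater<>());
    std::vector<bool> keep(n_expert_groups, false);
    for (int i = 0; i < n_group_used; ++i) {
        keep[group_scores[i].second] = true;
    }
    for (int e = 0; e < n_expert; ++e) {
        if (!keep[e / n_per_group]) {
            probs[e] = -1e30f; // stands in for the -INFINITY bias in the graph version
        }
    }

    // the usual top-k expert selection now only sees experts from the kept groups
    std::vector<int> idx(n_expert);
    for (int e = 0; e < n_expert; ++e) {
        idx[e] = e;
    }
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });
    for (int i = 0; i < n_expert_used; ++i) {
        std::printf("selected expert %d (p = %.2f)\n", idx[i], probs[idx[i]]);
    }
    return 0;
}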


@@ -257,10 +257,14 @@ public:
     void set_input(const llama_ubatch * ubatch) override;
-    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1] // n_tokens == n_batch
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
     const llama_hparams hparams;
     const llama_cparams cparams;


@@ -74,6 +74,8 @@ struct llama_hparams {
     uint32_t n_ff_chexp      = 0;
     uint32_t n_expert_shared = 0;
     uint32_t n_norm_groups   = 0;
+    uint32_t n_expert_groups = 0;
+    uint32_t n_group_used    = 0;
     uint32_t n_group_experts = 0;
     float expert_group_scale = 0.05f;


@@ -114,9 +114,12 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17B_16E:   return "17Bx16E (Scout)";
         case LLM_TYPE_17B_128E:  return "17Bx128E (Maverick)";
         case LLM_TYPE_A13B:      return "A13B";
+        case LLM_TYPE_7B_A1B:    return "7B.A1B";
         case LLM_TYPE_8B_A1B:    return "8B.A1B";
+        case LLM_TYPE_16B_A1B:   return "16B.A1B";
         case LLM_TYPE_21B_A3B:   return "21B.A3B";
         case LLM_TYPE_30B_A3B:   return "30B.A3B";
+        case LLM_TYPE_100B_A6B:  return "100B.A6B";
         case LLM_TYPE_106B_A12B: return "106B.A12B";
         case LLM_TYPE_235B_A22B: return "235B.A22B";
         case LLM_TYPE_300B_A47B: return "300B.A47B";
@ -401,6 +404,19 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s
     // add the device default buffer type
     buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
+    // add the device extra buffer type (if any)
+    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+        ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts");
+    if (ggml_backend_dev_get_extra_bufts_fn) {
+        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev);
+        while (extra_bufts && *extra_bufts) {
+            buft_list.emplace_back(dev, *extra_bufts);
+            ++extra_bufts;
+        }
+    }
     return buft_list;
 }
@ -421,11 +437,8 @@ struct llama_model::impl {
     llama_mlocks mlock_bufs;
     llama_mlocks mlock_mmaps;
-    // contexts where the model tensors metadata is stored
-    std::vector<ggml_context_ptr> ctxs;
-    // the model memory buffers for the tensor data
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // contexts where the model tensors metadata is stored as well ass the corresponding buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
     buft_list_t cpu_buft_list;
     std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
@ -478,7 +491,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
     // everything past this point is not vocab-related
-    if (hparams.vocab_only) {
+    // for CLIP models, we only need to load tensors, no hparams
+    if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
         return;
     }
@ -487,6 +501,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_COUNT,      hparams.n_expert_groups, false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used,    false);
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@ -502,8 +518,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
     if (hparams.n_expert > 0) {
         GGML_ASSERT(hparams.n_expert_used > 0);
+        GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert);
+        if (hparams.n_expert_groups > 1) {
+            GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0);
+            GGML_ASSERT(hparams.n_group_used > 0);
+            GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups);
+        }
     } else {
         GGML_ASSERT(hparams.n_expert_used == 0);
+        GGML_ASSERT(hparams.n_expert_groups == 0);
     }
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
@@ -1845,8 +1868,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    // TODO: Add llm type label (not sure this is useful)
+                switch (hparams.n_embd) {
+                    case 1536: type = LLM_TYPE_7B_A1B; break;
+                    case 2048: case 2560: type = LLM_TYPE_3B; break;
+                    case 4096: type = LLM_TYPE_32B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
@ -1902,6 +1927,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BAILINGMOE2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,               hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,              hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,               hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,                hparams.expert_gating_func);
+                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,              hparams.nextn_predict_layers, false);
+                // TODO: when MTP is implemented, this should probably be updated if needed
+                hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+                switch (hparams.n_layer) {
+                    case 20: type = LLM_TYPE_16B_A1B; break;
+                    case 21: type = LLM_TYPE_16B_A1B; break;
+                    case 32: type = LLM_TYPE_100B_A6B; break;
+                    case 33: type = LLM_TYPE_100B_A6B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DOTS1:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -2196,7 +2244,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     max_n_tensors += n_layer*2; // duplicated rope freq tensors
     const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@ -2211,12 +2266,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 throw std::runtime_error(format("failed to create ggml context"));
             }
-            ctx_map[buft] = ctx;
-            pimpl->ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
             return ctx;
         }
-        return it->second;
+        return it->second.get();
     };
     const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED;
@@ -5534,6 +5588,70 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
                     }
                 } break;
+            case LLM_ARCH_BAILINGMOE2:
+                {
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2");
+                    GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2");
+                    for (int i = 0; i < n_layer; ++i) {
+                        int flags = 0;
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            // skip all tensors in the NextN layers
+                            flags |= TENSOR_SKIP;
+                        }
+                        auto & layer = layers[i];
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
+                        if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
+                            const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared;
+                            layer.ffn_gate_inp    = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,    "weight", i), {n_embd, n_expert}, flags);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias",   i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, flags);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, flags);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, flags);
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_shexp}, flags);
+                        } else { // Dense layers
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, flags);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, flags);
+                        }
+                        // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            layer.nextn.eh_proj          = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ,          "weight", i), { 2 * n_embd, n_embd }, flags);
+                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS,     "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
+                            layer.nextn.enorm            = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM,            "weight", i), { n_embd }, flags);
+                            layer.nextn.hnorm            = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM,            "weight", i), { n_embd }, flags);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags);
+                            layer.layer_out_norm         = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM,         "weight", i), {n_embd}, flags);
+                        }
+                    }
+                } break;
             case LLM_ARCH_DOTS1:
                 {
                     const int64_t n_ff_exp = hparams.n_ff_exp;
@ -6079,16 +6197,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     pimpl->mappings.reserve(ml.mappings.size());
     // create the backend buffers
-    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
-    ctx_bufs.reserve(ctx_map.size());
+    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_buf_maps;
+    ctx_buf_maps.reserve(ctx_map.size());
     // Ensure we have enough capacity for the maximum backend buffer we will potentially create
     const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
-    pimpl->bufs.reserve(n_max_backend_buffer);
+    pimpl->ctxs_bufs.reserve(n_max_backend_buffer);
-    for (auto & it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx  = it.second;
+    for (auto & [buft, ctx_ptr] : ctx_map) {
+        ggml_context * ctx = ctx_ptr.get();
         // skip contexts without tensors
         if (ggml_get_first_tensor(ctx) == nullptr) {
@ -6112,6 +6229,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
             bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
+            ggml_backend_buffer_t buf = nullptr;
             if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
                 for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
                     // only the mmap region containing the tensors in the model is mapped to the backend buffer
@@ -6124,20 +6242,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        continue;
                    }
                    const size_t max_size = ggml_get_max_tensor_size(ctx);
-                   ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
+                   buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
                    if (buf == nullptr) {
                        throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
                    }
-                   pimpl->bufs.emplace_back(buf);
                    buf_map.emplace(idx, buf);
                }
            }
            else {
-               ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+               buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
                if (buf == nullptr) {
                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
                }
-               pimpl->bufs.emplace_back(buf);
                if (use_mlock && ggml_backend_buffer_is_host(buf)) {
                    pimpl->mlock_bufs.emplace_back(new llama_mlock);
                    auto & mlock_buf = pimpl->mlock_bufs.back();
@ -6148,10 +6264,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                buf_map.emplace(idx, buf);
            }
        }
+       pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
-       if (pimpl->bufs.empty()) {
-           throw std::runtime_error("failed to allocate buffer");
-       }
        for (auto & buf : buf_map) {
            // indicate that this buffer contains weights
@ -6159,7 +6272,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
        }
-       ctx_bufs.emplace_back(ctx, buf_map);
+       ctx_buf_maps.emplace_back(ctx, buf_map);
    }
    if (llama_supports_gpu_offload()) {
@ -6177,22 +6290,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
    }
    // print memory requirements per buffer type
-   for (auto & buf : pimpl->bufs) {
+   for (auto & [_, buf] : pimpl->ctxs_bufs) {
        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
    }
    // populate tensors_by_name
-   for (auto & ctx : pimpl->ctxs) {
+   for (auto & [ctx, _] : pimpl->ctxs_bufs) {
        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
            tensors_by_name.emplace_back(ggml_get_name(cur), cur);
        }
    }
    // load tensor data
-   for (auto & it : ctx_bufs) {
-       ggml_context * ctx = it.first;
-       auto & bufs = it.second;
-       if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
+   for (auto & [ctx, buf_map] : ctx_buf_maps) {
+       if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
            return false;
        }
    }
@@ -6232,8 +6343,8 @@ size_t llama_model::n_devices() const {
 std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : pimpl->ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@ -6396,6 +6507,19 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n", __func__, hparams.expert_weights_norm);
     }
+    if (arch == LLM_ARCH_BAILINGMOE2) {
+        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",   __func__, hparams.n_layer_dense_lead);
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",   __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_shexp           = %d\n",   __func__, hparams.n_ff_shexp);
+        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",   __func__, hparams.n_expert_shared);
+        LLAMA_LOG_INFO("%s: n_expert_groups      = %d\n",   __func__, hparams.n_expert_groups);
+        LLAMA_LOG_INFO("%s: n_group_used         = %d\n",   __func__, hparams.n_group_used);
+        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+        LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",   __func__, hparams.expert_weights_norm);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",   __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+        LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n",   __func__, hparams.nextn_predict_layers);
+    }
     if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n", __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
@ -11401,8 +11525,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
-struct llm_build_gemma_embedding_iswa : public llm_graph_context {
-    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+struct llm_build_gemma_embedding : public llm_graph_context {
+    llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
         ggml_tensor * cur;
@ -11419,8 +11543,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
-        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
-        auto * inp_attn = build_attn_inp_kv_iswa();
+        auto * inp_attn = build_attn_inp_no_cache();
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -17245,6 +17368,150 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };
+struct llm_build_bailingmoe2 : public llm_graph_context {
+    llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        inpL = build_inp_embd(model.tok_embd);
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+        auto * inp_attn = build_attn_inp_kv();
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+        const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+        for (int il = 0; il < n_transformer_layers; ++il) {
+            ggml_tensor * inpSA = inpL;
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+            // self_attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+            if (il == n_transformer_layers - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
+            cb(sa_out, "sa_out", il);
+            // MoE branch
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+            if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
+                            true, hparams.expert_weights_scale,
+                            (llama_expert_gating_func_type) hparams.expert_gating_func,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+            cur = ggml_add(ctx0, cur, sa_out);
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+            // input for next layer
+            inpL = cur;
+        }
+        cur = inpL;
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+        ggml_build_forward_expand(gf, cur);
+    }
+};
 struct llm_build_dots1 : public llm_graph_context {
     llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -17900,6 +18167,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);
+        res->t_embd = cur;
         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);
@ -19580,7 +19849,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
@ -19873,7 +20142,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            } break;
        case LLM_ARCH_GEMMA_EMBEDDING:
            {
-               llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+               llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
            } break;
        case LLM_ARCH_STARCODER2:
            {
@ -20045,6 +20314,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_bailingmoe>(*this, params);
            } break;
+       case LLM_ARCH_BAILINGMOE2:
+           {
+               llm = std::make_unique<llm_build_bailingmoe2>(*this, params);
+           } break;
        case LLM_ARCH_SEED_OSS:
            {
                llm = std::make_unique<llm_build_seed_oss>(*this, params);
@ -20220,6 +20493,7 @@ int32_t llama_n_head(const llama_model * model) {
 llama_rope_type llama_model_rope_type(const llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
+        case LLM_ARCH_CLIP:
         case LLM_ARCH_GPT2:
         case LLM_ARCH_GPTJ:
         case LLM_ARCH_MPT:
@ -20311,6 +20585,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_EXAONE:
        case LLM_ARCH_EXAONE4:
        case LLM_ARCH_MINICPM3:
+       case LLM_ARCH_BAILINGMOE2:
        case LLM_ARCH_DOTS1:
        case LLM_ARCH_HUNYUAN_MOE:
        case LLM_ARCH_OPENAI_MOE:


@@ -108,9 +108,12 @@ enum llm_type {
     LLM_TYPE_17B_16E,  // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_A13B,
+    LLM_TYPE_7B_A1B,
     LLM_TYPE_8B_A1B, // lfm2moe
+    LLM_TYPE_16B_A1B,
     LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
+    LLM_TYPE_100B_A6B,
     LLM_TYPE_106B_A12B, // GLM-4.5-Air
     LLM_TYPE_235B_A22B,
     LLM_TYPE_300B_A47B, // Ernie MoE big


@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}); });
} }
bool is_clip_model = false;
for (const auto * it : tensors) { for (const auto * it : tensors) {
const struct ggml_tensor * tensor = it->tensor; const struct ggml_tensor * tensor = it->tensor;
@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
qs.has_output = true; qs.has_output = true;
} }
is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
} }
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
// sanity checks for models that have attention layers // sanity checks for models that have attention layers
if (qs.n_attention_wv != 0) if (qs.n_attention_wv != 0 && !is_clip_model)
{ {
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads // attention layers have a non-zero number of kv heads
@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// do not quantize relative position bias (T5) // do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos; quantize &= name.find("attn_rel_b.weight") == std::string::npos;
// do not quantize specific multimodal tensors
quantize &= name.find(".position_embd.") == std::string::npos;
ggml_type new_type; ggml_type new_type;
void * new_data; void * new_data;
size_t new_size; size_t new_size;
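
Note on the two string checks added in this hunk (a minimal standalone sketch, not part of the commit; the tensor names in the example are illustrative): name.rfind("mm.", 0) == 0 can only succeed at offset 0, so it acts as a cheap starts-with test for multimodal projector tensors, while the ".position_embd." test is a plain substring match used to keep position embeddings unquantized.

    #include <cassert>
    #include <string>

    // starts-with test: rfind(prefix, 0) can only match at offset 0
    static bool is_mm_tensor(const std::string & name) {
        return name.rfind("mm.", 0) == 0;
    }

    // substring test: position embeddings are skipped during quantization
    static bool keep_unquantized(const std::string & name) {
        return name.find(".position_embd.") != std::string::npos;
    }

    int main() {
        assert(is_mm_tensor("mm.model.fc.weight"));          // illustrative name
        assert(!is_mm_tensor("blk.0.attn_q.weight"));
        assert(keep_unquantized("v.position_embd.weight"));  // illustrative name
        assert(!keep_unquantized("blk.0.ffn_up.weight"));
        return 0;
    }
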

View File

@ -1957,6 +1957,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
clean_spaces = false; clean_spaces = false;
} else if ( } else if (
tokenizer_pre == "bailingmoe" || tokenizer_pre == "bailingmoe" ||
tokenizer_pre == "bailingmoe2" ||
tokenizer_pre == "llada-moe") { tokenizer_pre == "llada-moe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
clean_spaces = false; clean_spaces = false;

View File

@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
} catch(const std::exception & e) { } catch(const std::exception & e) {
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
} }
if (model.arch == LLM_ARCH_CLIP) {
throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
}
try { try {
model.load_vocab(ml); model.load_vocab(ml);
} catch(const std::exception & e) { } catch(const std::exception & e) {
@ -314,6 +317,7 @@ struct llama_model * llama_model_load_from_splits(
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__); LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
return nullptr; return nullptr;
} }
splits.reserve(n_paths);
for (size_t i = 0; i < n_paths; ++i) { for (size_t i = 0; i < n_paths; ++i) {
splits.push_back(paths[i]); splits.push_back(paths[i]);
} }

View File

@ -30,6 +30,7 @@
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
// vision-specific // vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
#define KEY_IMAGE_SIZE "clip.vision.image_size" #define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_PATCH_SIZE "clip.vision.patch_size"
@ -48,6 +49,7 @@
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" #define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
// audio-specific // audio-specific
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins" #define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor" #define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"

View File

@ -2234,15 +2234,27 @@ struct clip_model_loader {
// projector type // projector type
std::string proj_type; std::string proj_type;
{ {
// default key
get_string(KEY_PROJ_TYPE, proj_type, false); get_string(KEY_PROJ_TYPE, proj_type, false);
if (!proj_type.empty()) {
model.proj_type = clip_projector_type_from_string(proj_type); // for models with mixed modalities
if (proj_type.empty()) {
if (modality == CLIP_MODALITY_VISION) {
get_string(KEY_VISION_PROJ_TYPE, proj_type, false);
} else if (modality == CLIP_MODALITY_AUDIO) {
get_string(KEY_AUDIO_PROJ_TYPE, proj_type, false);
} else {
GGML_ABORT("unknown modality");
} }
}
model.proj_type = clip_projector_type_from_string(proj_type);
if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) { if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str())); throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
} }
// correct arch for multimodal models // correct arch for multimodal models (legacy method)
if (model.proj_type == PROJECTOR_TYPE_QWEN25O) { if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
model.proj_type = modality == CLIP_MODALITY_VISION model.proj_type = modality == CLIP_MODALITY_VISION
? PROJECTOR_TYPE_QWEN25VL ? PROJECTOR_TYPE_QWEN25VL

View File

@ -23,7 +23,7 @@ problem.
8 files changed, 21 insertions(+), 2 deletions(-) 8 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index ff9135fe..8ba86f82 100644 index ff9135fe2..8ba86f824 100644
--- a/ggml/src/ggml-backend.cpp --- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp
@@ -113,7 +113,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { @@ -113,7 +113,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
@ -64,18 +64,18 @@ index ff9135fe..8ba86f82 100644
/* .init_tensor = */ NULL, // no initialization required /* .init_tensor = */ NULL, // no initialization required
/* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index ad1adba6..7d44f74f 100755 index 8bd5449f1..01e2df61a 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp --- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -843,6 +843,7 @@ static void ggml_backend_cann_buffer_free_buffer( @@ -820,6 +820,7 @@ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
ggml_backend_cann_buffer_context* ctx = static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {
(ggml_backend_cann_buffer_context*)buffer->context; ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
delete ctx; delete ctx;
+ delete buffer; + delete buffer;
} }
/** /**
@@ -1630,6 +1631,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf @@ -1560,6 +1561,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
*/ */
static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) { static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
ACL_CHECK(aclrtFreeHost(buffer->context)); ACL_CHECK(aclrtFreeHost(buffer->context));
@ -84,10 +84,10 @@ index ad1adba6..7d44f74f 100755
/** /**
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 856e9de2..c0b1e4c1 100644 index bc396b521..aefc6935e 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -567,6 +567,7 @@ struct ggml_backend_cuda_buffer_context { @@ -576,6 +576,7 @@ struct ggml_backend_cuda_buffer_context {
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
delete ctx; delete ctx;
@ -95,7 +95,7 @@ index 856e9de2..c0b1e4c1 100644
} }
static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) { static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
@@ -822,6 +823,7 @@ struct ggml_backend_cuda_split_buffer_context { @@ -831,6 +832,7 @@ struct ggml_backend_cuda_split_buffer_context {
static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
delete ctx; delete ctx;
@ -103,7 +103,7 @@ index 856e9de2..c0b1e4c1 100644
} }
static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) { static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
@@ -1103,6 +1105,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { @@ -1112,6 +1114,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
CUDA_CHECK(cudaFreeHost(buffer->context)); CUDA_CHECK(cudaFreeHost(buffer->context));
@ -112,7 +112,7 @@ index 856e9de2..c0b1e4c1 100644
static void * ggml_cuda_host_malloc(size_t size) { static void * ggml_cuda_host_malloc(size_t size) {
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index 7afc881f..bf096227 100644 index 7afc881fa..bf0962274 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp --- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b @@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b
@ -132,10 +132,10 @@ index 7afc881f..bf096227 100644
static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 79d21487..38c75018 100644 index db33a4ab6..c42ee26e1 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp --- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -3212,6 +3212,7 @@ struct ggml_backend_opencl_buffer_context { @@ -3266,6 +3266,7 @@ struct ggml_backend_opencl_buffer_context {
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
delete ctx; delete ctx;
@ -144,7 +144,7 @@ index 79d21487..38c75018 100644
static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index aad48d62..a46c0f52 100644 index a38df5a97..fd07e4a21 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp --- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -528,6 +528,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { @@ -528,6 +528,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
@ -156,10 +156,10 @@ index aad48d62..a46c0f52 100644
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 45b8c216..4ec9a592 100644 index b695ba051..37e853120 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp --- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -334,6 +334,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { @@ -352,6 +352,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
ggml_sycl_set_device(ctx->device); ggml_sycl_set_device(ctx->device);
delete ctx; delete ctx;
@ -167,7 +167,7 @@ index 45b8c216..4ec9a592 100644
} }
catch (sycl::exception const &exc) { catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -795,6 +796,7 @@ struct ggml_backend_sycl_split_buffer_context { @@ -813,6 +814,7 @@ struct ggml_backend_sycl_split_buffer_context {
static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
delete ctx; delete ctx;
@ -175,7 +175,7 @@ index 45b8c216..4ec9a592 100644
} }
static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
@@ -1137,6 +1139,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ @@ -1155,6 +1157,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_sycl_host_free(buffer->context); ggml_sycl_host_free(buffer->context);
@ -184,10 +184,10 @@ index 45b8c216..4ec9a592 100644
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3cd89c71..ed83236f 100644 index b783f7805..216dc167c 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -11600,6 +11600,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { @@ -11828,6 +11828,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
ggml_vk_destroy_buffer(ctx->dev_buffer); ggml_vk_destroy_buffer(ctx->dev_buffer);
delete ctx; delete ctx;
@ -195,7 +195,7 @@ index 3cd89c71..ed83236f 100644
} }
static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
@@ -11743,6 +11744,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe @@ -11971,6 +11972,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
ggml_vk_host_free(vk_instance.devices[0], buffer->context); ggml_vk_host_free(vk_instance.devices[0], buffer->context);
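
The hunks above all apply the same ownership change: ggml_backend_buffer_free() in ggml-backend.cpp no longer deletes the buffer object, so each backend's free_buffer callback now releases both its private context and the buffer wrapper itself. A minimal sketch of the pattern, with illustrative names rather than the real ggml types:

    struct example_buffer_context { /* backend-specific state */ };

    struct example_buffer {
        example_buffer_context * context;
    };

    static void example_buffer_free_buffer(example_buffer * buffer) {
        delete buffer->context; // backend-specific cleanup, unchanged
        delete buffer;          // new: the wrapper is freed by the backend, not the caller
    }

    int main() {
        example_buffer * buf = new example_buffer{ new example_buffer_context() };
        example_buffer_free_buffer(buf);
        return 0;
    }
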

View File

@ -10,7 +10,7 @@ logs instead of throwing an error
1 file changed, 3 insertions(+), 11 deletions(-) 1 file changed, 3 insertions(+), 11 deletions(-)
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 7fffd171..0b6edaf4 100644 index 639fecbd3..a7ce6f8e1 100644
--- a/src/llama-vocab.cpp --- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp
@@ -1812,16 +1812,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { @@ -1812,16 +1812,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
@ -31,7 +31,7 @@ index 7fffd171..0b6edaf4 100644
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if ( } else if (
tokenizer_pre == "llama3" || tokenizer_pre == "llama3" ||
@@ -1992,7 +1983,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { @@ -1993,7 +1984,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
clean_spaces = false; clean_spaces = false;
} else { } else {

View File

@ -10,7 +10,7 @@ filesystems for paths that include wide characters
1 file changed, 39 insertions(+) 1 file changed, 39 insertions(+)
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 98e68af2..6699b75a 100644 index f2abf8852..c984e6282 100644
--- a/tools/mtmd/clip.cpp --- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp
@@ -28,6 +28,19 @@ @@ -28,6 +28,19 @@
@ -33,7 +33,7 @@ index 98e68af2..6699b75a 100644
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
enum ffn_op_type { enum ffn_op_type {
@@ -2762,7 +2775,29 @@ struct clip_model_loader { @@ -2774,7 +2787,29 @@ struct clip_model_loader {
{ {
std::vector<uint8_t> read_buf; std::vector<uint8_t> read_buf;
@ -63,7 +63,7 @@ index 98e68af2..6699b75a 100644
if (!fin) { if (!fin) {
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
} }
@@ -2789,7 +2824,11 @@ struct clip_model_loader { @@ -2801,7 +2836,11 @@ struct clip_model_loader {
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
} }
} }
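
The clip.cpp change above wraps the model file open so that UTF-8 paths survive on Windows filesystems. The exact patch body is not shown in this excerpt; as a hedged sketch of the usual technique, the path is widened with MultiByteToWideChar() and passed to the wchar_t constructor that MSVC's std::ifstream offers as a nonstandard extension, while other platforms keep using the UTF-8 string directly:

    #include <fstream>
    #include <string>
    #ifdef _WIN32
    #include <windows.h>

    static std::wstring utf8_to_wide(const std::string & s) {
        if (s.empty()) {
            return std::wstring();
        }
        const int n = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), (int) s.size(), nullptr, 0);
        std::wstring w(n, L'\0');
        MultiByteToWideChar(CP_UTF8, 0, s.c_str(), (int) s.size(), &w[0], n);
        return w;
    }
    #endif

    static std::ifstream open_model_file(const std::string & fname) {
    #ifdef _WIN32
        return std::ifstream(utf8_to_wide(fname).c_str(), std::ios::binary); // MSVC wchar_t overload
    #else
        return std::ifstream(fname, std::ios::binary);
    #endif
    }
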

View File

@ -15,10 +15,10 @@ adds support for the Solar Pro architecture
7 files changed, 248 insertions(+), 1 deletion(-) 7 files changed, 248 insertions(+), 1 deletion(-)
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 869e4dcc..9f6b6ad2 100644 index 8ca769c5f..ab262ec0c 100644
--- a/src/llama-arch.cpp --- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp +++ b/src/llama-arch.cpp
@@ -81,6 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { @@ -82,6 +82,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" }, { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
{ LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_CHAMELEON, "chameleon" },
@ -26,7 +26,7 @@ index 869e4dcc..9f6b6ad2 100644
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" }, { LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" },
@@ -179,6 +180,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { @@ -183,6 +184,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
@ -34,7 +34,7 @@ index 869e4dcc..9f6b6ad2 100644
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
@@ -1893,6 +1895,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N @@ -1901,6 +1903,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
}, },
}, },
@ -59,7 +59,7 @@ index 869e4dcc..9f6b6ad2 100644
{ {
LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_WAVTOKENIZER_DEC,
{ {
@@ -2429,6 +2449,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { @@ -2469,6 +2489,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// this tensor is loaded for T5, but never used // this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
@ -68,10 +68,10 @@ index 869e4dcc..9f6b6ad2 100644
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h diff --git a/src/llama-arch.h b/src/llama-arch.h
index c3ae7165..dc7a362a 100644 index dea725c1a..ea2b4ffb9 100644
--- a/src/llama-arch.h --- a/src/llama-arch.h
+++ b/src/llama-arch.h +++ b/src/llama-arch.h
@@ -85,6 +85,7 @@ enum llm_arch { @@ -86,6 +86,7 @@ enum llm_arch {
LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_MOE,
LLM_ARCH_GRANITE_HYBRID, LLM_ARCH_GRANITE_HYBRID,
LLM_ARCH_CHAMELEON, LLM_ARCH_CHAMELEON,
@ -79,7 +79,7 @@ index c3ae7165..dc7a362a 100644
LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM, LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE, LLM_ARCH_BAILINGMOE,
@@ -183,6 +184,7 @@ enum llm_kv { @@ -187,6 +188,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
@ -87,7 +87,7 @@ index c3ae7165..dc7a362a 100644
LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@@ -432,6 +434,7 @@ enum llm_tensor { @@ -436,6 +438,7 @@ enum llm_tensor {
LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_ENC_OUTPUT_NORM,
LLM_TENSOR_CLS, LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT, LLM_TENSOR_CLS_OUT,
@ -96,7 +96,7 @@ index c3ae7165..dc7a362a 100644
LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_DW,
LLM_TENSOR_CONVNEXT_NORM, LLM_TENSOR_CONVNEXT_NORM,
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index db65d69e..b6bf6bbf 100644 index db65d69ea..b6bf6bbf2 100644
--- a/src/llama-hparams.cpp --- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp
@@ -151,6 +151,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { @@ -151,6 +151,14 @@ uint32_t llama_hparams::n_pos_per_embd() const {
@ -115,7 +115,7 @@ index db65d69e..b6bf6bbf 100644
if (il < n_layer) { if (il < n_layer) {
return swa_layers[il]; return swa_layers[il];
diff --git a/src/llama-hparams.h b/src/llama-hparams.h diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 4e7f73ec..80582728 100644 index 6fcf91b7d..24569a258 100644
--- a/src/llama-hparams.h --- a/src/llama-hparams.h
+++ b/src/llama-hparams.h +++ b/src/llama-hparams.h
@@ -64,6 +64,8 @@ struct llama_hparams { @@ -64,6 +64,8 @@ struct llama_hparams {
@ -127,7 +127,7 @@ index 4e7f73ec..80582728 100644
uint32_t n_layer_dense_lead = 0; uint32_t n_layer_dense_lead = 0;
uint32_t n_lora_q = 0; uint32_t n_lora_q = 0;
uint32_t n_lora_kv = 0; uint32_t n_lora_kv = 0;
@@ -248,6 +250,9 @@ struct llama_hparams { @@ -250,6 +252,9 @@ struct llama_hparams {
uint32_t n_pos_per_embd() const; uint32_t n_pos_per_embd() const;
@ -138,7 +138,7 @@ index 4e7f73ec..80582728 100644
bool has_kv(uint32_t il) const; bool has_kv(uint32_t il) const;
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index aa3a65f8..ee303bd5 100644 index aa3a65f87..ee303bd58 100644
--- a/src/llama-model-loader.cpp --- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp
@@ -466,7 +466,7 @@ namespace GGUFMeta { @@ -466,7 +466,7 @@ namespace GGUFMeta {
@ -151,10 +151,10 @@ index aa3a65f8..ee303bd5 100644
llama_model_loader::llama_model_loader( llama_model_loader::llama_model_loader(
const std::string & fname, const std::string & fname,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 36d495d6..74e1d162 100644 index 2a83d6627..54621ea39 100644
--- a/src/llama-model.cpp --- a/src/llama-model.cpp
+++ b/src/llama-model.cpp +++ b/src/llama-model.cpp
@@ -1865,6 +1865,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { @@ -1890,6 +1890,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
} break; } break;
@ -176,7 +176,7 @@ index 36d495d6..74e1d162 100644
case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_WAVTOKENIZER_DEC:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -5170,6 +5185,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { @@ -5224,6 +5239,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
@ -211,7 +211,7 @@ index 36d495d6..74e1d162 100644
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
@@ -16392,6 +16435,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { @@ -16515,6 +16558,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
} }
}; };
@ -377,7 +377,7 @@ index 36d495d6..74e1d162 100644
// ref: https://github.com/facebookresearch/chameleon // ref: https://github.com/facebookresearch/chameleon
// based on the original build_llama() function, changes: // based on the original build_llama() function, changes:
// * qk-norm // * qk-norm
@@ -19827,6 +20029,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { @@ -20096,6 +20298,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{ {
llm = std::make_unique<llm_build_chameleon>(*this, params); llm = std::make_unique<llm_build_chameleon>(*this, params);
} break; } break;
@ -388,7 +388,7 @@ index 36d495d6..74e1d162 100644
case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_WAVTOKENIZER_DEC:
{ {
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params); llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params);
@@ -20057,6 +20263,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { @@ -20331,6 +20537,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_CHAMELEON: case LLM_ARCH_CHAMELEON:
@ -397,7 +397,7 @@ index 36d495d6..74e1d162 100644
case LLM_ARCH_NEO_BERT: case LLM_ARCH_NEO_BERT:
case LLM_ARCH_SMOLLM3: case LLM_ARCH_SMOLLM3:
diff --git a/src/llama-model.h b/src/llama-model.h diff --git a/src/llama-model.h b/src/llama-model.h
index 7f48662f..ec3fbd33 100644 index 248f85410..4a7924aaa 100644
--- a/src/llama-model.h --- a/src/llama-model.h
+++ b/src/llama-model.h +++ b/src/llama-model.h
@@ -76,6 +76,7 @@ enum llm_type { @@ -76,6 +76,7 @@ enum llm_type {
@ -408,7 +408,7 @@ index 7f48662f..ec3fbd33 100644
LLM_TYPE_27B, LLM_TYPE_27B,
LLM_TYPE_30B, LLM_TYPE_30B,
LLM_TYPE_32B, LLM_TYPE_32B,
@@ -387,6 +388,8 @@ struct llama_layer { @@ -390,6 +391,8 @@ struct llama_layer {
struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_beta = nullptr;
struct ggml_tensor * ffn_act_eps = nullptr; struct ggml_tensor * ffn_act_eps = nullptr;

View File

@ -12,7 +12,7 @@ regex
2 files changed, 22 insertions(+), 1 deletion(-) 2 files changed, 22 insertions(+), 1 deletion(-)
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 0b6edaf4..3de95c67 100644 index a7ce6f8e1..8064dc197 100644
--- a/src/llama-vocab.cpp --- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp
@@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { @@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
@ -25,7 +25,7 @@ index 0b6edaf4..3de95c67 100644
"\\s+$", "\\s+$",
"[一-龥ࠀ-一가-퟿]+", "[一-龥ࠀ-一가-퟿]+",
diff --git a/src/unicode.cpp b/src/unicode.cpp diff --git a/src/unicode.cpp b/src/unicode.cpp
index 65f36651..ce336a22 100644 index 65f366517..ce336a228 100644
--- a/src/unicode.cpp --- a/src/unicode.cpp
+++ b/src/unicode.cpp +++ b/src/unicode.cpp
@@ -2,6 +2,11 @@ @@ -2,6 +2,11 @@

View File

@ -8,7 +8,7 @@ Subject: [PATCH] maintain ordering for rules for grammar
1 file changed, 1 insertion(+), 1 deletion(-) 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index db1f0b23..f4de7e34 100644 index dd9b51a9e..d88f43209 100644
--- a/common/json-schema-to-grammar.cpp --- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp
@@ -308,7 +308,7 @@ private: @@ -308,7 +308,7 @@ private:

View File

@ -11,10 +11,10 @@ with the fastest acceleration is loaded
1 file changed, 13 insertions(+), 8 deletions(-) 1 file changed, 13 insertions(+), 8 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 136afec7..f794d9cf 100644 index e96b5c403..a55d9b280 100644
--- a/ggml/src/ggml-backend-reg.cpp --- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp
@@ -175,7 +175,7 @@ struct ggml_backend_reg_entry { @@ -179,7 +179,7 @@ struct ggml_backend_reg_entry {
struct ggml_backend_registry { struct ggml_backend_registry {
std::vector<ggml_backend_reg_entry> backends; std::vector<ggml_backend_reg_entry> backends;
@ -23,7 +23,7 @@ index 136afec7..f794d9cf 100644
ggml_backend_registry() { ggml_backend_registry() {
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
@@ -223,7 +223,7 @@ struct ggml_backend_registry { @@ -230,7 +230,7 @@ struct ggml_backend_registry {
} }
} }
@ -32,7 +32,7 @@ index 136afec7..f794d9cf 100644
if (!reg) { if (!reg) {
return; return;
} }
@@ -234,15 +234,20 @@ struct ggml_backend_registry { @@ -241,15 +241,20 @@ struct ggml_backend_registry {
#endif #endif
backends.push_back({ reg, std::move(handle) }); backends.push_back({ reg, std::move(handle) });
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
@ -56,7 +56,7 @@ index 136afec7..f794d9cf 100644
} }
ggml_backend_reg_t load_backend(const fs::path & path, bool silent) { ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {
@@ -286,7 +291,7 @@ struct ggml_backend_registry { @@ -293,7 +298,7 @@ struct ggml_backend_registry {
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str()); GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());
@ -65,7 +65,7 @@ index 136afec7..f794d9cf 100644
return reg; return reg;
} }
@@ -309,7 +314,7 @@ struct ggml_backend_registry { @@ -316,7 +321,7 @@ struct ggml_backend_registry {
// remove devices // remove devices
devices.erase( devices.erase(
std::remove_if(devices.begin(), devices.end(), std::remove_if(devices.begin(), devices.end(),
@ -74,7 +74,7 @@ index 136afec7..f794d9cf 100644
devices.end()); devices.end());
// remove backend // remove backend
@@ -367,7 +372,7 @@ size_t ggml_backend_dev_count() { @@ -374,7 +379,7 @@ size_t ggml_backend_dev_count() {
ggml_backend_dev_t ggml_backend_dev_get(size_t index) { ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
GGML_ASSERT(index < ggml_backend_dev_count()); GGML_ASSERT(index < ggml_backend_dev_count());
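
This patch keeps the registry's device list ordered so that the backend with the fastest acceleration ends up first. The score plumbing itself is not visible in the excerpt above, so the following is only an illustrative sketch of score-ordered registration (types and scoring are invented for the example, not the patched code):

    #include <algorithm>
    #include <utility>
    #include <vector>

    struct fake_device { const char * name; };

    static void register_device(std::vector<std::pair<fake_device *, int>> & devices,
                                fake_device * dev, int score) {
        devices.emplace_back(dev, score);
        // higher score first; stable sort keeps registration order for equal scores
        std::stable_sort(devices.begin(), devices.end(),
                         [](const std::pair<fake_device *, int> & a,
                            const std::pair<fake_device *, int> & b) {
                             return a.second > b.second;
                         });
    }
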

View File

@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants
1 file changed, 2 insertions(+) 1 file changed, 2 insertions(+)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 892c2331..09fdf5fc 100644 index ba281b8e6..ead235878 100644
--- a/ggml/src/CMakeLists.txt --- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt
@@ -310,6 +310,7 @@ function(ggml_add_cpu_backend_variant tag_name) @@ -314,6 +314,7 @@ function(ggml_add_cpu_backend_variant tag_name)
endif() endif()
ggml_add_cpu_backend_variant_impl(${tag_name}) ggml_add_cpu_backend_variant_impl(${tag_name})
@ -19,7 +19,7 @@ index 892c2331..09fdf5fc 100644
endfunction() endfunction()
ggml_add_backend(CPU) ggml_add_backend(CPU)
@@ -320,6 +321,7 @@ if (GGML_CPU_ALL_VARIANTS) @@ -324,6 +325,7 @@ if (GGML_CPU_ALL_VARIANTS)
elseif (GGML_CPU_ARM_ARCH) elseif (GGML_CPU_ARM_ARCH)
message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS") message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS")
endif() endif()

View File

@ -9,10 +9,10 @@ disable amx as it reduces performance on some systems
1 file changed, 4 deletions(-) 1 file changed, 4 deletions(-)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 09fdf5fc..0609c650 100644 index ead235878..f9a6587f1 100644
--- a/ggml/src/CMakeLists.txt --- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt
@@ -330,10 +330,6 @@ if (GGML_CPU_ALL_VARIANTS) @@ -334,10 +334,6 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)

View File

@ -13,7 +13,7 @@ such as vocab fields
3 files changed, 7 insertions(+), 5 deletions(-) 3 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
index 79ee2020..3efb22f0 100644 index 79ee20206..3efb22f01 100644
--- a/ggml/include/gguf.h --- a/ggml/include/gguf.h
+++ b/ggml/include/gguf.h +++ b/ggml/include/gguf.h
@@ -114,6 +114,7 @@ extern "C" { @@ -114,6 +114,7 @@ extern "C" {
@ -25,7 +25,7 @@ index 79ee2020..3efb22f0 100644
// get ith C string from array with given key_id // get ith C string from array with given key_id
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index 8cc4ef1c..d950dbdf 100644 index 8cc4ef1cf..d950dbdf5 100644
--- a/ggml/src/gguf.cpp --- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp
@@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id @@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
@ -53,7 +53,7 @@ index 8cc4ef1c..d950dbdf 100644
} }
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 3de95c67..217ede47 100644 index 8064dc197..31f49801c 100644
--- a/src/llama-vocab.cpp --- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp
@@ -1768,9 +1768,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { @@ -1768,9 +1768,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {

View File

@ -8,7 +8,7 @@ Subject: [PATCH] ollama debug tensor
1 file changed, 6 insertions(+) 1 file changed, 6 insertions(+)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index ba2a36d9..99509b0c 100644 index 9ec485cfa..4b2f8b7bd 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c --- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -15,6 +15,8 @@ @@ -15,6 +15,8 @@
@ -20,7 +20,7 @@ index ba2a36d9..99509b0c 100644
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW #include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -2887,6 +2889,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { @@ -2891,6 +2893,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);
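
The two insertions add a per-node debug hook right after ggml_compute_forward(). The hook body is ollama-specific and not shown here; a conceptual sketch of an environment-gated tensor trace (the variable name and types below are assumptions, not the real code):

    #include <cstdio>
    #include <cstdlib>

    struct fake_tensor { const char * name; }; // stand-in for ggml_tensor

    static void debug_after_compute(const fake_tensor * node) {
        static const bool enabled = std::getenv("OLLAMA_DEBUG") != nullptr; // assumed gate
        if (enabled) {
            std::fprintf(stderr, "computed node %s\n", node->name);
        }
    }
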

View File

@ -10,7 +10,7 @@ Subject: [PATCH] add ollama vocab for grammar support
3 files changed, 58 insertions(+), 9 deletions(-) 3 files changed, 58 insertions(+), 9 deletions(-)
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index bed706bb..b51cee09 100644 index bed706bb2..b51cee090 100644
--- a/src/llama-grammar.cpp --- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@ -137,7 +137,7 @@ index bed706bb..b51cee09 100644
+ } + }
+} +}
diff --git a/src/llama-grammar.h b/src/llama-grammar.h diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index f8c291de..2a3a62db 100644 index f8c291de9..2a3a62db3 100644
--- a/src/llama-grammar.h --- a/src/llama-grammar.h
+++ b/src/llama-grammar.h +++ b/src/llama-grammar.h
@@ -6,8 +6,19 @@ @@ -6,8 +6,19 @@
@ -184,7 +184,7 @@ index f8c291de..2a3a62db 100644
const char * grammar_root, const char * grammar_root,
bool lazy, bool lazy,
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 55d2e355..da34526b 100644 index 55d2e355f..da34526b1 100644
--- a/src/llama-sampling.cpp --- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp
@@ -1563,7 +1563,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { @@ -1563,7 +1563,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
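
This patch lets the grammar sampler run against a vocabulary supplied by ollama instead of requiring a llama_vocab. The struct layout is only partially visible above, so the following is a loose sketch of the kind of adapter involved (member and method names are assumptions):

    #include <cstdint>
    #include <map>
    #include <string>

    // minimal token-id -> text mapping that a grammar engine can query
    struct ollama_vocab_sketch {
        std::map<uint32_t, std::string> token_to_piece;
        uint32_t eog_token = 0;

        const std::string & piece(uint32_t id) const {
            static const std::string empty;
            auto it = token_to_piece.find(id);
            return it != token_to_piece.end() ? it->second : empty;
        }

        bool is_eog(uint32_t id) const { return id == eog_token; }
    };
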

View File

@ -4,15 +4,15 @@ Date: Thu, 1 May 2025 13:45:12 -0700
Subject: [PATCH] add argsort and cuda copy for i32 Subject: [PATCH] add argsort and cuda copy for i32
--- ---
ggml/src/ggml-cpu/ops.cpp | 43 +++++++++++ ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++
ggml/src/ggml-cuda/argsort.cu | 102 ++++++++++++++++++++++++++- ggml/src/ggml-cuda/argsort.cu | 122 ++++++++++++++++++++++++---
ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++ ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++
ggml/src/ggml-cuda/cpy.cu | 43 +++++++++++ ggml/src/ggml-cuda/cpy.cu | 40 +++++++++
ggml/src/ggml-metal/ggml-metal.metal | 64 +++++++++++++++++ ggml/src/ggml-metal/ggml-metal.metal | 64 ++++++++++++++
5 files changed, 256 insertions(+), 2 deletions(-) 5 files changed, 263 insertions(+), 12 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 1c43865f..31478dd8 100644 index b52f0f847..902fdad69 100644
--- a/ggml/src/ggml-cpu/ops.cpp --- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7889,6 +7889,45 @@ static void ggml_compute_forward_argsort_f32( @@ -7889,6 +7889,45 @@ static void ggml_compute_forward_argsort_f32(
@ -73,10 +73,10 @@ index 1c43865f..31478dd8 100644
{ {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 607ded85..53b02634 100644 index 6e7b90d42..08dd30525 100644
--- a/ggml/src/ggml-cuda/argsort.cu --- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu
@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co @@ -168,13 +168,107 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
} }
} }
@ -185,19 +185,42 @@ index 607ded85..53b02634 100644
GGML_ASSERT( dst->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -183,18 +277,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
- argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); -#ifdef GGML_CUDA_USE_CUB
- const int ncols_pad = next_power_of_2(ncols);
- const size_t shared_mem = ncols_pad * sizeof(int);
- const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
-
- if (shared_mem > max_shared_mem || ncols > 1024) {
- ggml_cuda_pool & pool = ctx.pool();
- argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ if (src0->type == GGML_TYPE_I32) { + if (src0->type == GGML_TYPE_I32) {
+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); + argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
} else {
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
- }
+#ifdef GGML_CUDA_USE_CUB
+ const int ncols_pad = next_power_of_2(ncols);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ if (shared_mem > max_shared_mem || ncols > 1024) {
+ ggml_cuda_pool & pool = ctx.pool();
+ argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ } else { + } else {
+ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ }
#else
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
+ } + }
} }
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
index e621cb98..597c0c8b 100644 index e621cb981..597c0c8b3 100644
--- a/ggml/src/ggml-cuda/cpy-utils.cuh --- a/ggml/src/ggml-cuda/cpy-utils.cuh
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh
@@ -215,3 +215,9 @@ template<typename src_t, typename dst_t> @@ -215,3 +215,9 @@ template<typename src_t, typename dst_t>
@ -211,19 +234,18 @@ index e621cb98..597c0c8b 100644
+ *dst = *src; + *dst = *src;
+} +}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 746f4396..911220e9 100644 index 12d5bf776..a0e34030e 100644
--- a/ggml/src/ggml-cuda/cpy.cu --- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu
@@ -277,6 +277,47 @@ static void ggml_cpy_f32_iq4_nl_cuda( @@ -251,6 +251,43 @@ static void ggml_cpy_f32_iq4_nl_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
+template <cpy_kernel_t cpy_1> +template <cpy_kernel_t cpy_1>
+static __global__ void cpy_i32_i32( +static __global__ void cpy_i32_i32(
+ const char *cx, char *cdst, const int ne, + const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+ +
+ const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+ +
@ -243,39 +265,37 @@ index 746f4396..911220e9 100644
+ const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
+ const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
+ +
+ char * cdst_ptr = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst; + cpy_1(cx + x_offset, cdst + dst_offset);
+ cpy_1(cx + x_offset, cdst_ptr + dst_offset);
+} +}
+ +
+
+static void ggml_cpy_i32_i32_cuda( +static void ggml_cpy_i32_i32_cuda(
+ const char * cx, char * cdst, const int ne, + const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+ cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
+ +
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> + cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream, cdst_indirect, graph_cpynode_index); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream);
+} +}
+ +
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0); const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1)); GGML_ASSERT(ne == ggml_nelements(src1));
@@ -372,6 +413,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg @@ -332,6 +369,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + // TODO consider converting to template
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 74a9aa99..375a0c7f 100644 index 2c2f01415..50b8071de 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal --- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4346,8 +4346,72 @@ kernel void kernel_argsort_f32_i32( @@ -4467,8 +4467,72 @@ kernel void kernel_argsort_f32_i32(
} }
} }
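
On the CUDA side the dispatch shown above now routes GGML_TYPE_I32 inputs to a dedicated kernel, while the f32 path keeps choosing between the shared-memory bitonic sort and the CUB fallback. For reference, a plain CPU sketch of what an i32 argsort computes (not the ggml kernel itself): dst receives the index permutation that orders the source row.

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    static std::vector<int32_t> argsort_i32(const std::vector<int32_t> & src, bool ascending) {
        std::vector<int32_t> idx(src.size());
        std::iota(idx.begin(), idx.end(), 0); // 0, 1, 2, ...
        std::sort(idx.begin(), idx.end(), [&](int32_t a, int32_t b) {
            return ascending ? src[a] < src[b] : src[a] > src[b];
        });
        return idx; // indices into src, not the sorted values
    }
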

View File

@ -11,7 +11,7 @@ Subject: [PATCH] graph memory reporting on failure
4 files changed, 40 insertions(+), 3 deletions(-) 4 files changed, 40 insertions(+), 3 deletions(-)
diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h
index 2cb150fd..7ab3f019 100644 index 2cb150fd2..7ab3f0192 100644
--- a/ggml/include/ggml-alloc.h --- a/ggml/include/ggml-alloc.h
+++ b/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h
@@ -65,6 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n( @@ -65,6 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n(
@ -23,7 +23,7 @@ index 2cb150fd..7ab3f019 100644
// Utils // Utils
// Create a buffer and allocate all the tensors in a ggml_context // Create a buffer and allocate all the tensors in a ggml_context
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index f1b74078..c54ff98b 100644 index f1b740785..c54ff98bf 100644
--- a/ggml/include/ggml-backend.h --- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h
@@ -318,6 +318,7 @@ extern "C" { @@ -318,6 +318,7 @@ extern "C" {
@ -35,7 +35,7 @@ index f1b74078..c54ff98b 100644
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 929bc448..eee9d3b1 100644 index c830c0965..363853873 100644
--- a/ggml/src/ggml-alloc.c --- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c
@@ -486,6 +486,7 @@ struct node_alloc { @@ -486,6 +486,7 @@ struct node_alloc {
@ -64,7 +64,7 @@ index 929bc448..eee9d3b1 100644
free(galloc->buffers); free(galloc->buffers);
free(galloc->buf_tallocs); free(galloc->buf_tallocs);
free(galloc->node_allocs); free(galloc->node_allocs);
@@ -869,6 +874,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c @@ -891,6 +896,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
} }
} }
@ -73,7 +73,7 @@ index 929bc448..eee9d3b1 100644
// reallocate buffers if needed // reallocate buffers if needed
for (int i = 0; i < galloc->n_buffers; i++) { for (int i = 0; i < galloc->n_buffers; i++) {
// if the buffer type is used multiple times, we reuse the same buffer // if the buffer type is used multiple times, we reuse the same buffer
@@ -898,14 +905,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c @@ -920,14 +927,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
ggml_vbuffer_free(galloc->buffers[i]); ggml_vbuffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
@ -96,7 +96,7 @@ index 929bc448..eee9d3b1 100644
} }
bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
@@ -1060,6 +1072,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { @@ -1082,6 +1094,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
return ggml_vbuffer_size(galloc->buffers[buffer_id]); return ggml_vbuffer_size(galloc->buffers[buffer_id]);
} }
@ -120,7 +120,7 @@ index 929bc448..eee9d3b1 100644
static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 8ba86f82..cb2b9956 100644 index 8ba86f824..cb2b99562 100644
--- a/ggml/src/ggml-backend.cpp --- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp
@@ -1809,6 +1809,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe @@ -1809,6 +1809,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe


@ -12,7 +12,7 @@ with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
3 files changed, 63 insertions(+), 6 deletions(-) 3 files changed, 63 insertions(+), 6 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index c54ff98b..229bf387 100644 index c54ff98bf..229bf387b 100644
--- a/ggml/include/ggml-backend.h --- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h
@@ -158,6 +158,7 @@ extern "C" { @@ -158,6 +158,7 @@ extern "C" {
@ -24,7 +24,7 @@ index c54ff98b..229bf387 100644
size_t memory_total; size_t memory_total;
// device type // device type
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c0b1e4c1..5b852f69 100644 index aefc6935e..cc201afff 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -183,6 +183,51 @@ static int ggml_cuda_parse_id(char devName[]) { @@ -183,6 +183,51 @@ static int ggml_cuda_parse_id(char devName[]) {
@ -110,7 +110,7 @@ index c0b1e4c1..5b852f69 100644
std::string device_name(prop.name); std::string device_name(prop.name);
if (device_name == "NVIDIA GeForce MX450") { if (device_name == "NVIDIA GeForce MX450") {
turing_devices_without_mma.push_back({ id, device_name }); turing_devices_without_mma.push_back({ id, device_name });
@@ -3276,6 +3323,7 @@ struct ggml_backend_cuda_device_context { @@ -3268,6 +3315,7 @@ struct ggml_backend_cuda_device_context {
std::string name; std::string name;
std::string description; std::string description;
std::string pci_bus_id; std::string pci_bus_id;
@ -118,7 +118,7 @@ index c0b1e4c1..5b852f69 100644
}; };
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -3288,6 +3336,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t @@ -3280,6 +3328,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str(); return ctx->description.c_str();
} }
@ -130,7 +130,7 @@ index c0b1e4c1..5b852f69 100644
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device); ggml_cuda_set_device(ctx->device);
@@ -3304,6 +3357,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back @@ -3296,6 +3349,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->name = ggml_backend_cuda_device_get_name(dev); props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev); props->description = ggml_backend_cuda_device_get_description(dev);
@ -138,7 +138,7 @@ index c0b1e4c1..5b852f69 100644
props->type = ggml_backend_cuda_device_get_type(dev); props->type = ggml_backend_cuda_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -3873,6 +3927,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { @@ -3869,6 +3923,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
cudaDeviceProp prop; cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name; dev_ctx->description = prop.name;
@ -147,7 +147,7 @@ index c0b1e4c1..5b852f69 100644
char pci_bus_id[16] = {}; char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index bf096227..f2ff9f32 100644 index bf0962274..f2ff9f322 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp --- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -538,6 +538,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen @@ -538,6 +538,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen


@ -10,7 +10,7 @@ Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2 files changed, 13 insertions(+) 2 files changed, 13 insertions(+)
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 4d487581..35a0d25e 100644 index 4d487581a..35a0d25ed 100644
--- a/tools/mtmd/mtmd.cpp --- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp
@@ -79,6 +79,16 @@ enum mtmd_slice_tmpl { @@ -79,6 +79,16 @@ enum mtmd_slice_tmpl {
@ -31,7 +31,7 @@ index 4d487581..35a0d25e 100644
return "<__media__>"; return "<__media__>";
} }
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index f4ea07d3..cf287224 100644 index f4ea07d3a..cf287224b 100644
--- a/tools/mtmd/mtmd.h --- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h
@@ -75,6 +75,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk; @@ -75,6 +75,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk;


@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc
1 file changed, 1 insertion(+), 1 deletion(-) 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 99509b0c..b13a491d 100644 index 4b2f8b7bd..046646282 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c --- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2437,7 +2437,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { @@ -2441,7 +2441,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
// Newer Windows 11 versions aggresively park (offline) CPU cores and often place // Newer Windows 11 versions aggresively park (offline) CPU cores and often place
// all our threads onto the first 4 cores which results in terrible performance with // all our threads onto the first 4 cores which results in terrible performance with
// n_threads > 4 // n_threads > 4


@ -9,7 +9,7 @@ Only enable BF16 on supported MacOS versions (v14+)
1 file changed, 6 insertions(+), 1 deletion(-) 1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m
index 052efb7a..b47dc787 100644 index 052efb7ac..b47dc7879 100644
--- a/ggml/src/ggml-metal/ggml-metal-context.m --- a/ggml/src/ggml-metal/ggml-metal-context.m
+++ b/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m
@@ -125,7 +125,12 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { @@ -125,7 +125,12 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {


@ -178,19 +178,19 @@ index 3191faaa4..32f14c811 100644
static const struct ggml_backend_i ggml_backend_cpu_i = { static const struct ggml_backend_i ggml_backend_cpu_i = {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 5b852f690..c555cd30f 100644 index cc201afff..02d413467 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2684,7 +2684,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { @@ -2693,7 +2693,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
- bool use_cuda_graph) { - bool use_cuda_graph) {
+ int batch_size, bool use_cuda_graph) { + int batch_size, bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph // Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
@@ -2718,24 +2718,34 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud @@ -2726,24 +2726,34 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
#endif #endif
} }
@ -240,8 +240,8 @@ index 5b852f690..c555cd30f 100644
+ } + }
} }
if (node->op == GGML_OP_CPY) { if (!use_cuda_graph) {
@@ -3132,7 +3142,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx @@ -3128,7 +3138,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
} }
} }
@ -250,12 +250,12 @@ index 5b852f690..c555cd30f 100644
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device); ggml_cuda_set_device(cuda_ctx->device);
@@ -3170,7 +3180,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, @@ -3166,7 +3176,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
if (use_cuda_graph) { if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
- use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
+ use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, batch_size, use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) { if (use_cuda_graph && cuda_graph_update_required) {
@ -278,10 +278,10 @@ index f2ff9f322..05ff6a5a6 100644
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ed83236f4..bd3ece516 100644 index 216dc167c..3a6bbe564 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -12015,7 +12015,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru @@ -12357,7 +12357,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
return num_adds; return num_adds;
} }
@ -290,7 +290,7 @@ index ed83236f4..bd3ece516 100644
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -12211,6 +12211,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg @@ -12561,6 +12561,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
UNUSED(backend); UNUSED(backend);


@ -8,10 +8,10 @@ Subject: [PATCH] Disable ggml-blas on macos v13 and older
1 file changed, 5 insertions(+) 1 file changed, 5 insertions(+)
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index 5b888cdd..2a9ff7f6 100644 index 88d088952..6a38a51a2 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp --- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -506,6 +506,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = { @@ -507,6 +507,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
}; };
ggml_backend_reg_t ggml_backend_blas_reg(void) { ggml_backend_reg_t ggml_backend_blas_reg(void) {


@ -8,7 +8,7 @@ Subject: [PATCH] fix mtmd-audio.cpp build on windows
1 file changed, 1 insertion(+), 1 deletion(-) 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp
index 4d053895..84bdc277 100644 index 4d053895c..84bdc2777 100644
--- a/tools/mtmd/mtmd-audio.cpp --- a/tools/mtmd/mtmd-audio.cpp
+++ b/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp
@@ -1,6 +1,6 @@ @@ -1,6 +1,6 @@


@ -219,7 +219,7 @@ index 41eef3b5f..c81a2e48a 100644
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index e0abde542..e98044bd8 100644 index 41ff89c4d..2931c15ca 100644
--- a/ggml/src/ggml-cuda/common.cuh --- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh
@@ -35,6 +35,41 @@ @@ -35,6 +35,41 @@
@ -274,7 +274,7 @@ index e0abde542..e98044bd8 100644
}; };
template<typename T> template<typename T>
@@ -999,11 +1037,11 @@ struct ggml_backend_cuda_context { @@ -992,11 +1030,11 @@ struct ggml_backend_cuda_context {
// pool // pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES]; std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
@ -288,7 +288,7 @@ index e0abde542..e98044bd8 100644
} }
return *pools[device]; return *pools[device];
} }
@@ -1011,4 +1049,20 @@ struct ggml_backend_cuda_context { @@ -1004,4 +1042,20 @@ struct ggml_backend_cuda_context {
ggml_cuda_pool & pool() { ggml_cuda_pool & pool() {
return pool(device); return pool(device);
} }
@ -310,10 +310,10 @@ index e0abde542..e98044bd8 100644
+ } + }
}; };
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c555cd30f..eb3db0f19 100644 index 02d413467..f79e5d65c 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -350,6 +350,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { @@ -359,6 +359,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {
// #define DEBUG_CUDA_MALLOC // #define DEBUG_CUDA_MALLOC
@ -322,7 +322,7 @@ index c555cd30f..eb3db0f19 100644
// buffer pool for cuda (legacy) // buffer pool for cuda (legacy)
struct ggml_cuda_pool_leg : public ggml_cuda_pool { struct ggml_cuda_pool_leg : public ggml_cuda_pool {
static const int MAX_BUFFERS = 256; static const int MAX_BUFFERS = 256;
@@ -362,9 +364,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { @@ -371,9 +373,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
size_t pool_size = 0; size_t pool_size = 0;
@ -337,7 +337,7 @@ index c555cd30f..eb3db0f19 100644
} }
~ggml_cuda_pool_leg() { ~ggml_cuda_pool_leg() {
@@ -372,7 +377,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { @@ -381,7 +386,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
for (int i = 0; i < MAX_BUFFERS; ++i) { for (int i = 0; i < MAX_BUFFERS; ++i) {
ggml_cuda_buffer & b = buffer_pool[i]; ggml_cuda_buffer & b = buffer_pool[i];
if (b.ptr != nullptr) { if (b.ptr != nullptr) {
@ -348,7 +348,7 @@ index c555cd30f..eb3db0f19 100644
pool_size -= b.size; pool_size -= b.size;
} }
} }
@@ -420,8 +427,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { @@ -429,8 +436,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
void * ptr; void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size); size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256); look_ahead_size = 256 * ((look_ahead_size + 255)/256);
@ -366,7 +366,7 @@ index c555cd30f..eb3db0f19 100644
*actual_size = look_ahead_size; *actual_size = look_ahead_size;
pool_size += look_ahead_size; pool_size += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC #ifdef DEBUG_CUDA_MALLOC
@@ -441,10 +455,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { @@ -450,10 +464,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
} }
} }
GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
@ -389,7 +389,7 @@ index c555cd30f..eb3db0f19 100644
}; };
// pool with virtual memory // pool with virtual memory
@@ -456,18 +480,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { @@ -465,18 +489,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
CUdeviceptr pool_addr = 0; CUdeviceptr pool_addr = 0;
size_t pool_used = 0; size_t pool_used = 0;
size_t pool_size = 0; size_t pool_size = 0;
@ -417,7 +417,7 @@ index c555cd30f..eb3db0f19 100644
#if defined(GGML_USE_HIP) #if defined(GGML_USE_HIP)
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) { for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
@@ -494,35 +524,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { @@ -503,35 +533,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
@ -493,7 +493,7 @@ index c555cd30f..eb3db0f19 100644
// add to the pool // add to the pool
pool_size += reserve_size; pool_size += reserve_size;
@@ -555,16 +599,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { @@ -564,16 +608,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
// all deallocations must be in reverse order of the allocations // all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
} }
@ -521,7 +521,7 @@ index c555cd30f..eb3db0f19 100644
} }
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
@@ -748,11 +800,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac @@ -757,11 +809,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
} }
static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@ -543,7 +543,7 @@ index c555cd30f..eb3db0f19 100644
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
size_t size = ggml_nbytes(tensor); size_t size = ggml_nbytes(tensor);
int64_t ne0 = tensor->ne[0]; int64_t ne0 = tensor->ne[0];
@@ -776,6 +837,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface @@ -785,6 +846,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface
/* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
/* .is_host = */ NULL, /* .is_host = */ NULL,
@ -551,7 +551,7 @@ index c555cd30f..eb3db0f19 100644
}; };
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
@@ -3003,6 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, @@ -2986,6 +3048,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
@ -559,7 +559,7 @@ index c555cd30f..eb3db0f19 100644
// flag used to determine whether it is an integrated_gpu // flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
@@ -3018,6 +3081,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx @@ -3001,6 +3064,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue; continue;
} }
@ -571,7 +571,7 @@ index c555cd30f..eb3db0f19 100644
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) { if (!disable_fusion) {
@@ -3144,6 +3212,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx @@ -3140,6 +3208,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@ -579,7 +579,7 @@ index c555cd30f..eb3db0f19 100644
ggml_cuda_set_device(cuda_ctx->device); ggml_cuda_set_device(cuda_ctx->device);
@@ -3223,6 +3292,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, @@ -3215,6 +3284,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
@ -651,7 +651,7 @@ index c555cd30f..eb3db0f19 100644
static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@@ -3263,6 +3397,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { @@ -3255,6 +3389,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .event_record = */ ggml_backend_cuda_event_record, /* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait, /* .event_wait = */ ggml_backend_cuda_event_wait,
/* .graph_optimize = */ NULL, /* .graph_optimize = */ NULL,


@ -8,7 +8,7 @@ Subject: [PATCH] decode: disable output_all
1 file changed, 1 insertion(+), 2 deletions(-) 1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index e7526e7d..53a5e3a9 100644 index bd348bcad..8b4a89d38 100644
--- a/src/llama-context.cpp --- a/src/llama-context.cpp
+++ b/src/llama-context.cpp +++ b/src/llama-context.cpp
@@ -974,8 +974,7 @@ int llama_context::decode(const llama_batch & batch_inp) { @@ -974,8 +974,7 @@ int llama_context::decode(const llama_batch & batch_inp) {


@ -16,7 +16,7 @@ unused then it can be reset to free these data structures.
6 files changed, 32 insertions(+), 2 deletions(-) 6 files changed, 32 insertions(+), 2 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 1ff53ed03..ba181d09d 100644 index b3b5b356a..69223c488 100644
--- a/ggml/include/ggml-backend.h --- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h
@@ -178,6 +178,7 @@ extern "C" { @@ -178,6 +178,7 @@ extern "C" {
@ -28,7 +28,7 @@ index 1ff53ed03..ba181d09d 100644
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 3c3f22fc0..43c91d9f2 100644 index 7bdf9d81f..21b35ac5c 100644
--- a/ggml/src/ggml-backend-impl.h --- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h
@@ -195,6 +195,10 @@ extern "C" { @@ -195,6 +195,10 @@ extern "C" {
@ -43,7 +43,7 @@ index 3c3f22fc0..43c91d9f2 100644
struct ggml_backend_device { struct ggml_backend_device {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 6ef5eeafa..0b757af59 100644 index c81a2e48a..9b0a9b91f 100644
--- a/ggml/src/ggml-backend.cpp --- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp
@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par @@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
@ -62,7 +62,7 @@ index 6ef5eeafa..0b757af59 100644
GGML_ASSERT(device); GGML_ASSERT(device);
return device->iface.get_buffer_type(device); return device->iface.get_buffer_type(device);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 811462c79..87c6c34a4 100644 index f79e5d65c..c9333689f 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -107,6 +107,11 @@ int ggml_cuda_get_device() { @@ -107,6 +107,11 @@ int ggml_cuda_get_device() {
@ -77,7 +77,7 @@ index 811462c79..87c6c34a4 100644
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device); ggml_cuda_set_device(device);
cudaError_t err; cudaError_t err;
@@ -3515,7 +3520,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back @@ -3499,7 +3504,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->id = ggml_backend_cuda_device_get_id(dev); props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev); props->type = ggml_backend_cuda_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
@ -89,7 +89,7 @@ index 811462c79..87c6c34a4 100644
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY #ifdef GGML_CUDA_NO_PEER_COPY
@@ -3948,6 +3956,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g @@ -3936,6 +3944,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
} }
@ -101,7 +101,7 @@ index 811462c79..87c6c34a4 100644
static const ggml_backend_device_i ggml_backend_cuda_device_interface = { static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description, /* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -3964,6 +3977,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { @@ -3952,6 +3965,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
@ -122,10 +122,10 @@ index 890c10364..1f06be80e 100644
#define cudaError_t hipError_t #define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
diff --git a/src/llama.cpp b/src/llama.cpp diff --git a/src/llama.cpp b/src/llama.cpp
index fe5a7a835..d821a96a0 100644 index ab2e9868a..74c49e651 100644
--- a/src/llama.cpp --- a/src/llama.cpp
+++ b/src/llama.cpp +++ b/src/llama.cpp
@@ -267,10 +267,12 @@ static struct llama_model * llama_model_load_from_file_impl( @@ -270,10 +270,12 @@ static struct llama_model * llama_model_load_from_file_impl(
for (auto * dev : model->devices) { for (auto * dev : model->devices) {
ggml_backend_dev_props props; ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props); ggml_backend_dev_get_props(dev, &props);
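
The llama.cpp hunk above is the load-time loop that fills a ggml_backend_dev_props per device and logs the identification fields these patches add. As a rough sketch of the same call pattern outside llama.cpp, using only the public ggml-backend API; the patched-in extras (device id, PCI bus id, and get_props leaving memory at zero so callers query it explicitly) are treated as assumptions specific to this vendored tree and are not read directly here:

```cpp
// Hedged sketch (not part of the patch): enumerate backend devices and query
// their properties, mirroring the loop in the llama.cpp hunk above. Memory is
// queried explicitly because the patched get_props may leave it at 0.
#include <cstdio>
#include "ggml-backend.h"

int main() {
    ggml_backend_load_all();  // load dynamic backends, if the build uses them

    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);

        ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);

        size_t free = 0, total = 0;
        ggml_backend_dev_memory(dev, &free, &total);

        std::printf("device %zu: %s (%s), %zu/%zu MiB free\n",
                    i, props.name, props.description,
                    free / (1024 * 1024), total / (1024 * 1024));
    }
    return 0;
}
```

This only illustrates the call pattern; in llama.cpp the same loop feeds the device log line touched by the hunk above.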


@ -8,7 +8,7 @@ Subject: [PATCH] harden uncaught exception registration
1 file changed, 6 insertions(+), 2 deletions(-) 1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
index 0d388d45..f5bcb446 100644 index 0d388d455..f5bcb446d 100644
--- a/ggml/src/ggml.cpp --- a/ggml/src/ggml.cpp
+++ b/ggml/src/ggml.cpp +++ b/ggml/src/ggml.cpp
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{ @@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{


@ -45,7 +45,7 @@ index 69223c488..6510e0cba 100644
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 0609c6503..aefe43bdd 100644 index f9a6587f1..03f359ae9 100644
--- a/ggml/src/CMakeLists.txt --- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt
@@ -209,6 +209,8 @@ add_library(ggml-base @@ -209,6 +209,8 @@ add_library(ggml-base
@ -58,7 +58,7 @@ index 0609c6503..aefe43bdd 100644
target_include_directories(ggml-base PRIVATE .) target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 5787e8cd5..d232bf828 100644 index c9333689f..41b00af83 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() { @@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@ -90,7 +90,7 @@ index 5787e8cd5..d232bf828 100644
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
ggml_cuda_parse_uuid(prop, id).c_str()); ggml_cuda_parse_uuid(prop, id).c_str());
@@ -3476,6 +3491,11 @@ struct ggml_backend_cuda_device_context { @@ -3468,6 +3483,11 @@ struct ggml_backend_cuda_device_context {
std::string description; std::string description;
std::string pci_bus_id; std::string pci_bus_id;
std::string id; std::string id;
@ -102,7 +102,7 @@ index 5787e8cd5..d232bf828 100644
}; };
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -3496,6 +3516,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { @@ -3488,6 +3508,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device); ggml_cuda_set_device(ctx->device);
@ -131,7 +131,7 @@ index 5787e8cd5..d232bf828 100644
CUDA_CHECK(cudaMemGetInfo(free, total)); CUDA_CHECK(cudaMemGetInfo(free, total));
} }
@@ -3504,6 +3546,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend @@ -3496,6 +3538,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
return GGML_BACKEND_DEVICE_TYPE_GPU; return GGML_BACKEND_DEVICE_TYPE_GPU;
} }
@ -139,7 +139,7 @@ index 5787e8cd5..d232bf828 100644
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
@@ -3517,6 +3560,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back @@ -3509,6 +3552,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
// If you need the memory data, call ggml_backend_dev_memory() explicitly. // If you need the memory data, call ggml_backend_dev_memory() explicitly.
props->memory_total = props->memory_free = 0; props->memory_total = props->memory_free = 0;
@ -159,7 +159,7 @@ index 5787e8cd5..d232bf828 100644
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY #ifdef GGML_CUDA_NO_PEER_COPY
bool events = false; bool events = false;
@@ -4079,6 +4135,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { @@ -4075,6 +4131,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
if (!initialized) { if (!initialized) {
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
@ -167,7 +167,7 @@ index 5787e8cd5..d232bf828 100644
for (int i = 0; i < ggml_cuda_info().device_count; i++) { for (int i = 0; i < ggml_cuda_info().device_count; i++) {
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
@@ -4094,6 +4151,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { @@ -4090,6 +4147,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
dev_ctx->pci_bus_id = pci_bus_id; dev_ctx->pci_bus_id = pci_bus_id;
@ -204,11 +204,11 @@ index 1f06be80e..2f9ef2dc0 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index d0fb3bcca..b63edd0c1 100644 index e9201cdc6..44ae76d66 100644
--- a/ggml/src/ggml-impl.h --- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h
@@ -638,6 +638,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx @@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
} }
+// Management libraries for fetching more accurate free VRAM data +// Management libraries for fetching more accurate free VRAM data
@ -243,10 +243,10 @@ index 05ff6a5a6..032dee76d 100644
/* .async = */ true, /* .async = */ true,
/* .host_buffer = */ false, /* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index bd3ece516..7cfb14a54 100644 index 3a6bbe564..d2c278a35 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -231,6 +231,7 @@ class vk_memory_logger; @@ -229,6 +229,7 @@ class vk_memory_logger;
#endif #endif
class vk_perf_logger; class vk_perf_logger;
static void ggml_vk_destroy_buffer(vk_buffer& buf); static void ggml_vk_destroy_buffer(vk_buffer& buf);
@ -254,7 +254,7 @@ index bd3ece516..7cfb14a54 100644
static constexpr uint32_t mul_mat_vec_max_cols = 8; static constexpr uint32_t mul_mat_vec_max_cols = 8;
static constexpr uint32_t p021_max_gqa_ratio = 8; static constexpr uint32_t p021_max_gqa_ratio = 8;
@@ -11585,6 +11586,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ @@ -11813,6 +11814,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
snprintf(description, description_size, "%s", props.deviceName.data()); snprintf(description, description_size, "%s", props.deviceName.data());
} }
@ -284,7 +284,7 @@ index bd3ece516..7cfb14a54 100644
// backend interface // backend interface
#define UNUSED GGML_UNUSED #define UNUSED GGML_UNUSED
@@ -12392,31 +12416,102 @@ void ggml_backend_vk_get_device_description(int device, char * description, size @@ -12761,31 +12785,102 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
ggml_vk_get_device_description(dev_idx, description, description_size); ggml_vk_get_device_description(dev_idx, description, description_size);
} }
@ -404,7 +404,7 @@ index bd3ece516..7cfb14a54 100644
break; break;
} }
} }
@@ -12449,8 +12544,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { @@ -12818,8 +12913,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
} }
} }
@ -419,7 +419,7 @@ index bd3ece516..7cfb14a54 100644
} }
vk::PhysicalDeviceProperties2 props = {}; vk::PhysicalDeviceProperties2 props = {};
@@ -12467,19 +12567,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { @@ -12836,19 +12936,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
char pci_bus_id[16] = {}; char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
@ -453,7 +453,7 @@ index bd3ece516..7cfb14a54 100644
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
@@ -12491,9 +12596,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de @@ -12860,9 +12965,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
return ctx->description.c_str(); return ctx->description.c_str();
} }
@ -469,7 +469,7 @@ index bd3ece516..7cfb14a54 100644
} }
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -12517,8 +12627,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml @@ -12886,8 +12996,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->name = ggml_backend_vk_device_get_name(dev); props->name = ggml_backend_vk_device_get_name(dev);
props->description = ggml_backend_vk_device_get_description(dev); props->description = ggml_backend_vk_device_get_description(dev);
@ -480,7 +480,7 @@ index bd3ece516..7cfb14a54 100644
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = { props->caps = {
/* .async = */ false, /* .async = */ false,
@@ -12526,6 +12637,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml @@ -12895,6 +13006,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
/* .buffer_from_host_ptr = */ false, /* .buffer_from_host_ptr = */ false,
/* .events = */ false, /* .events = */ false,
}; };
@ -494,7 +494,7 @@ index bd3ece516..7cfb14a54 100644
} }
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -12954,6 +13072,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -13365,6 +13483,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
if (!initialized) { if (!initialized) {
@ -503,7 +503,7 @@ index bd3ece516..7cfb14a54 100644
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
char desc[256]; char desc[256];
@@ -12962,12 +13082,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -13373,12 +13493,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->name = GGML_VK_NAME + std::to_string(i); ctx->name = GGML_VK_NAME + std::to_string(i);
ctx->description = desc; ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;


@ -1,49 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Julius Tischbein <ju.tischbein@gmail.com>
Date: Wed, 15 Oct 2025 13:54:15 +0200
Subject: [PATCH] CUDA: Changing the CUDA scheduling strategy to spin (#16585)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* CUDA set scheduling strategy to spinning for cc121
* Using prop.major and prop.minor, include HIP and MUSA
* Exclude HIP and MUSA
* Remove trailing whitespace
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* Remove empty line
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---
ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index b075a18be..d62f412d6 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -340,6 +340,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
+
+ // Temporary performance fix:
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
+ // TODO: Check for future drivers the default scheduling strategy and
+ // remove this call again when cudaDeviceScheduleSpin is default.
+ if (prop.major == 12 && prop.minor == 1) {
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
+ }
+
#endif // defined(GGML_USE_HIP)
}


@ -8,10 +8,10 @@ Subject: [PATCH] report LoadLibrary failures
1 file changed, 12 insertions(+) 1 file changed, 12 insertions(+)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index f794d9cfa..3a855ab2e 100644 index a55d9b280..ec6f7f1e9 100644
--- a/ggml/src/ggml-backend-reg.cpp --- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp
@@ -118,6 +118,18 @@ static dl_handle * dl_load_library(const fs::path & path) { @@ -122,6 +122,18 @@ static dl_handle * dl_load_library(const fs::path & path) {
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str()); HMODULE handle = LoadLibraryW(path.wstring().c_str());


@ -13,7 +13,7 @@ interleaved version used for qwen3vl
4 files changed, 11 insertions(+), 30 deletions(-) 4 files changed, 11 insertions(+), 30 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 31478dd8e..4d1ed207e 100644 index 902fdad69..70955347d 100644
--- a/ggml/src/ggml-cpu/ops.cpp --- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp
@@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init( @@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init(
@ -62,10 +62,10 @@ index d058504cd..287fe9d2c 100644
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 375a0c7fd..9866c96b4 100644 index 50b8071de..65a3183c8 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal --- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3858,15 +3858,11 @@ kernel void kernel_rope_multi( @@ -3888,15 +3888,11 @@ kernel void kernel_rope_multi(
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims; const int sector = ic % sect_dims;


@ -12,7 +12,7 @@ Subject: [PATCH] Add memory detection using DXGI + PDH
create mode 100644 ggml/src/mem_dxgi_pdh.cpp create mode 100644 ggml/src/mem_dxgi_pdh.cpp
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index aefe43bdd..21fe4640c 100644 index 03f359ae9..4b3e5efb5 100644
--- a/ggml/src/CMakeLists.txt --- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt
@@ -211,6 +211,7 @@ add_library(ggml-base @@ -211,6 +211,7 @@ add_library(ggml-base
@ -24,10 +24,10 @@ index aefe43bdd..21fe4640c 100644
target_include_directories(ggml-base PRIVATE .) target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index b63edd0c1..81cad8cf3 100644 index 44ae76d66..639d551a2 100644
--- a/ggml/src/ggml-impl.h --- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h
@@ -645,6 +645,9 @@ GGML_API void ggml_nvml_release(); @@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init(); GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total); GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API void ggml_hip_mgmt_release(); GGML_API void ggml_hip_mgmt_release();
@ -38,7 +38,7 @@ index b63edd0c1..81cad8cf3 100644
#ifdef __cplusplus #ifdef __cplusplus
} }
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 7cfb14a54..a1c46d0b3 100644 index d2c278a35..221e29509 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); @@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
@ -49,7 +49,7 @@ index 7cfb14a54..a1c46d0b3 100644
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
VkStructureType sType; VkStructureType sType;
@@ -12433,6 +12434,7 @@ struct ggml_backend_vk_device_context { @@ -12802,6 +12803,7 @@ struct ggml_backend_vk_device_context {
std::string pci_id; std::string pci_id;
std::string id; std::string id;
std::string uuid; std::string uuid;
@ -57,7 +57,7 @@ index 7cfb14a54..a1c46d0b3 100644
int major; int major;
int minor; int minor;
int driver_major; int driver_major;
@@ -12448,8 +12450,22 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size @@ -12817,8 +12819,22 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
vk::PhysicalDeviceProperties2 props2; vk::PhysicalDeviceProperties2 props2;
vkdev.getProperties2(&props2); vkdev.getProperties2(&props2);
@ -81,7 +81,7 @@ index 7cfb14a54..a1c46d0b3 100644
{ {
// Use vendor specific management libraries for best VRAM reporting if available // Use vendor specific management libraries for best VRAM reporting if available
switch (props2.properties.vendorID) { switch (props2.properties.vendorID) {
@@ -12477,8 +12493,8 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size @@ -12846,8 +12862,8 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
break; break;
} }
} }
@ -91,7 +91,7 @@ index 7cfb14a54..a1c46d0b3 100644
*total = 0; *total = 0;
*free = 0; *free = 0;
vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props;
@@ -13089,7 +13105,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -13500,7 +13516,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
/* .reg = */ reg, /* .reg = */ reg,
/* .context = */ ctx, /* .context = */ ctx,
}); });
@ -99,7 +99,7 @@ index 7cfb14a54..a1c46d0b3 100644
// Gather additional information about the device // Gather additional information about the device
int dev_idx = vk_instance.device_indices[i]; int dev_idx = vk_instance.device_indices[i];
vk::PhysicalDeviceProperties props1; vk::PhysicalDeviceProperties props1;
@@ -13112,6 +13127,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -13523,6 +13538,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
} }
} }
ctx->uuid = oss.str(); ctx->uuid = oss.str();


@ -0,0 +1,19 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
// backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);
GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);
#ifdef __cplusplus
}
#endif


@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices, size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
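
The hunk above drops the free_mem/total_mem arguments and passes the device list straight to the RPC server. A hedged sketch of a caller under the new prototype; the endpoint, thread count, device filter, and the "ggml-rpc.h" include are illustrative assumptions, only the signature itself comes from the patch:

```cpp
// Hedged sketch of driving the updated ggml_backend_rpc_start_server()
// prototype shown above. Endpoint, cache_dir and n_threads are placeholders;
// the old free_mem/total_mem arguments are gone, so presumably the server
// derives memory from the devices themselves.
#include <vector>
#include "ggml-backend.h"
#include "ggml-rpc.h"

int main() {
    ggml_backend_load_all();

    // Collect the devices to expose over RPC (here: every GPU device).
    std::vector<ggml_backend_dev_t> devices;
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
            devices.push_back(dev);
        }
    }

    // Serve the selected devices on a local endpoint, no cache directory.
    ggml_backend_rpc_start_server("0.0.0.0:50052", /*cache_dir=*/nullptr,
                                  /*n_threads=*/4,
                                  devices.size(), devices.data());
    return 0;
}
```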


@ -577,6 +577,10 @@ extern "C" {
GGML_UNARY_OP_EXP, GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_XIELU, GGML_UNARY_OP_XIELU,
GGML_UNARY_OP_FLOOR,
GGML_UNARY_OP_CEIL,
GGML_UNARY_OP_ROUND,
GGML_UNARY_OP_TRUNC,
GGML_UNARY_OP_COUNT, GGML_UNARY_OP_COUNT,
}; };
@ -1151,6 +1155,46 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_floor(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_floor_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_ceil(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_ceil_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_round(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_round_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
/**
* Truncates the fractional part of each element in the tensor (towards zero).
* For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
* Similar to std::trunc in C/C++.
*/
GGML_API struct ggml_tensor * ggml_trunc(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_trunc_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// xIELU activation function // xIELU activation function
// x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
// where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
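A minimal CPU-side usage sketch for the four new rounding ops (not part of the diff). It assumes the usual context/graph helpers from ggml.h plus ggml-cpu.h for ggml_graph_compute_with_ctx; the expected outputs are shown in the comments.

#include <stdio.h>
#include <string.h>
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    const float vals[4] = { 3.7f, -2.9f, 0.5f, -0.5f };
    memcpy(x->data, vals, sizeof(vals));

    struct ggml_tensor * f = ggml_floor(ctx, x);   // -> 3, -3, 0, -1
    struct ggml_tensor * t = ggml_trunc(ctx, x);   // -> 3, -2, 0, -0   (towards zero)

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, f);
    ggml_build_forward_expand(gf, t);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    printf("floor(3.7) = %.1f, trunc(-2.9) = %.1f\n", ((float *) f->data)[0], ((float *) t->data)[1]);
    ggml_free(ctx);
    return 0;
}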

View File

@ -310,6 +310,10 @@ function(ggml_add_cpu_backend_variant tag_name)
foreach (feat ${ARGN}) foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON) set(GGML_INTERNAL_${feat} ON)
endforeach() endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
endif() endif()
ggml_add_cpu_backend_variant_impl(${tag_name}) ggml_add_cpu_backend_variant_impl(${tag_name})
@ -372,6 +376,14 @@ if (GGML_CPU_ALL_VARIANTS)
else() else()
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}") message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
endif() endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else() else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}") message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif() endif()
@ -391,6 +403,7 @@ ggml_add_backend(Vulkan)
ggml_add_backend(WebGPU) ggml_add_backend(WebGPU)
ggml_add_backend(zDNN) ggml_add_backend(zDNN)
ggml_add_backend(OpenCL) ggml_add_backend(OpenCL)
ggml_add_backend(Hexagon)
foreach (target ggml-base ggml) foreach (target ggml-base ggml)
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>) target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)

View File

@ -603,6 +603,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated; return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
} }
// free the extra space at the end if the new tensor is smaller
static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
GGML_ASSERT(parent_size >= node_size);
if (parent_size > node_size) {
struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
struct buffer_address p_addr = p_hn->addr;
p_addr.offset += node_size;
size_t extra_size = parent_size - node_size;
AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
}
}
static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) { static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
GGML_ASSERT(buffer_id >= 0); GGML_ASSERT(buffer_id >= 0);
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
@ -648,6 +668,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
hn->addr = p_hn->addr; hn->addr = p_hn->addr;
p_hn->allocated = false; // avoid freeing the parent p_hn->allocated = false; // avoid freeing the parent
view_src_hn->allocated = false; view_src_hn->allocated = false;
ggml_gallocr_free_extra_space(galloc, node, view_src);
return; return;
} }
} else { } else {
@ -655,6 +676,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
hn->buffer_id = p_hn->buffer_id; hn->buffer_id = p_hn->buffer_id;
hn->addr = p_hn->addr; hn->addr = p_hn->addr;
p_hn->allocated = false; // avoid freeing the parent p_hn->allocated = false; // avoid freeing the parent
ggml_gallocr_free_extra_space(galloc, node, parent);
return; return;
} }
} }
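A worked example of the arithmetic above, using made-up sizes: only the unused tail of the parent's allocation is handed back to the dynamic allocator, so the prefix that the reusing node inherits stays valid.

#include <stddef.h>

static void example_free_extra_space(void) {
    const size_t parent_offset = 8192;                       // where the parent block starts (hypothetical)
    const size_t parent_size   = 4096;                       // parent was padded to this size
    const size_t node_size     = 1536;                       // the reusing node only needs this much
    const size_t free_offset   = parent_offset + node_size;  // 9728: the tail starts after the reused prefix
    const size_t extra_size    = parent_size - node_size;    // 2560 bytes returned to the free list
    (void) free_offset; (void) extra_size;
}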

View File

@ -57,6 +57,10 @@
#include "ggml-opencl.h" #include "ggml-opencl.h"
#endif #endif
#ifdef GGML_USE_HEXAGON
#include "ggml-hexagon.h"
#endif
#ifdef GGML_USE_BLAS #ifdef GGML_USE_BLAS
#include "ggml-blas.h" #include "ggml-blas.h"
#endif #endif
@ -211,6 +215,9 @@ struct ggml_backend_registry {
#ifdef GGML_USE_OPENCL #ifdef GGML_USE_OPENCL
register_backend(ggml_backend_opencl_reg()); register_backend(ggml_backend_opencl_reg());
#endif #endif
#ifdef GGML_USE_HEXAGON
register_backend(ggml_backend_hexagon_reg());
#endif
#ifdef GGML_USE_CANN #ifdef GGML_USE_CANN
register_backend(ggml_backend_cann_reg()); register_backend(ggml_backend_cann_reg());
#endif #endif
@ -615,6 +622,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("sycl", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path);
ggml_backend_load_best("vulkan", silent, dir_path); ggml_backend_load_best("vulkan", silent, dir_path);
ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("opencl", silent, dir_path);
ggml_backend_load_best("hexagon", silent, dir_path);
ggml_backend_load_best("musa", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path);
ggml_backend_load_best("cpu", silent, dir_path); ggml_backend_load_best("cpu", silent, dir_path);
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
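A hedged sketch of verifying at runtime that the newly probed backends (hexagon included) were picked up; it relies only on the generic registry accessors from ggml-backend.h.

#include <stdio.h>
#include "ggml-backend.h"

static void list_registered_backends(void) {
    ggml_backend_load_all();   // probes hexagon alongside the other dynamic backends
    for (size_t i = 0; i < ggml_backend_reg_count(); ++i) {
        printf("backend %zu: %s\n", i, ggml_backend_reg_name(ggml_backend_reg_get(i)));
    }
}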

View File

@ -466,7 +466,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected") message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c) list(APPEND GGML_CPU_SOURCES
ggml-cpu/arch/s390/quants.c)
# for native compilation
if (GGML_NATIVE)
# check machine level to determine target
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS) file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS}) string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
@ -487,8 +492,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.") message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
list(APPEND ARCH_FLAGS -march=native -mtune=native) list(APPEND ARCH_FLAGS -march=native -mtune=native)
endif() endif()
# for cross-compilation
elseif(GGML_CPU_ALL_VARIANTS)
# range through IBM z15 to z17
# NOTE: update when a new hardware level is released
foreach (ZHW RANGE 15 17)
if(DEFINED GGML_INTERNAL_Z${ZHW})
message(STATUS "z${ZHW} cross-compile target")
list(APPEND ARCH_FLAGS -march=z${ZHW})
endif()
endforeach()
endif()
if (GGML_VXE) if (GGML_VXE OR GGML_INTERNAL_VXE)
message(STATUS "VX/VXE/VXE2 enabled") message(STATUS "VX/VXE/VXE2 enabled")
list(APPEND ARCH_FLAGS -mvx -mzvector) list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE) list(APPEND ARCH_DEFINITIONS GGML_VXE)

View File

@ -2186,6 +2186,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
{ {
n_tasks = 1; n_tasks = 1;
} break; } break;
@ -3569,13 +3573,17 @@ void ggml_cpu_init(void) {
#ifdef GGML_USE_OPENMP #ifdef GGML_USE_OPENMP
//if (!getenv("OMP_WAIT_POLICY")) { //if (!getenv("OMP_WAIT_POLICY")) {
// // set the wait policy to active, so that OpenMP threads don't sleep // // set the wait policy to active, so that OpenMP threads don't sleep
// putenv("OMP_WAIT_POLICY=active"); // setenv("OMP_WAIT_POLICY", "active", 0)
//} //}
if (!getenv("KMP_BLOCKTIME")) { if (!getenv("KMP_BLOCKTIME")) {
// set the time to wait before sleeping a thread // set the time to wait before sleeping a thread
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
putenv("KMP_BLOCKTIME=200"); // 200ms #ifdef _WIN32
_putenv_s("KMP_BLOCKTIME", "200"); // 200ms
#else
setenv("KMP_BLOCKTIME", "200", 0); // 200ms
#endif
} }
#endif #endif
} }
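The same portability pattern, pulled out as a standalone helper sketch (not code from the diff): set an environment variable only when the user has not already provided one, on both the MSVC CRT and POSIX.

#include <stdlib.h>

static void set_env_default(const char * name, const char * value) {
    if (getenv(name) != NULL) {
        return;                                 // keep the explicit user override
    }
#ifdef _WIN32
    _putenv_s(name, value);                     // MSVC CRT variant, copies its arguments
#else
    setenv(name, value, /*overwrite=*/ 0);      // POSIX; 0 also preserves an existing value
#endif
}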

View File

@ -9033,6 +9033,22 @@ void ggml_compute_forward_unary(
{ {
ggml_compute_forward_exp(params, dst); ggml_compute_forward_exp(params, dst);
} break; } break;
case GGML_UNARY_OP_FLOOR:
{
ggml_compute_forward_floor(params, dst);
} break;
case GGML_UNARY_OP_CEIL:
{
ggml_compute_forward_ceil(params, dst);
} break;
case GGML_UNARY_OP_ROUND:
{
ggml_compute_forward_round(params, dst);
} break;
case GGML_UNARY_OP_TRUNC:
{
ggml_compute_forward_trunc(params, dst);
} break;
case GGML_UNARY_OP_XIELU: case GGML_UNARY_OP_XIELU:
{ {
ggml_compute_forward_xielu(params, dst); ggml_compute_forward_xielu(params, dst);

View File

@ -73,6 +73,22 @@ static inline float op_log(float x) {
return logf(x); return logf(x);
} }
static inline float op_floor(float x) {
return floorf(x);
}
static inline float op_ceil(float x) {
return ceilf(x);
}
static inline float op_round(float x) {
return roundf(x);
}
static inline float op_trunc(float x) {
return truncf(x);
}
template <float (*op)(float), typename src0_t, typename dst_t> template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) { static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32; constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
unary_op<op_log>(params, dst); unary_op<op_log>(params, dst);
} }
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_floor>(params, dst);
}
void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_ceil>(params, dst);
}
void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_round>(params, dst);
}
void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_trunc>(params, dst);
}
void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
const float alpha_n = ggml_get_op_params_f32(dst, 1); const float alpha_n = ggml_get_op_params_f32(dst, 1);
const float alpha_p = ggml_get_op_params_f32(dst, 2); const float alpha_p = ggml_get_op_params_f32(dst, 2);
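For intuition, this is roughly what the unary_op<op_floor> instantiation reduces to on a contiguous F32 tensor; it is a sketch only, since the real template also handles F16/BF16 conversion and splits rows across threads.

#include <math.h>
#include <stdint.h>

static void apply_floor_f32(float * y, const float * x, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        y[i] = floorf(x[i]);   // op_floor from the hunk above
    }
}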

View File

@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -1,5 +1,81 @@
#include "argsort.cuh" #include "argsort.cuh"
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
const int col = blockIdx.x * blockDim.x + threadIdx.x;
const int row = blockIdx.y;
if (col < ncols && row < nrows) {
indices[row * ncols + col] = col;
}
}
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx <= nrows) {
offsets[idx] = idx * ncols;
}
}
#ifdef GGML_CUDA_USE_CUB
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();
int * d_offsets = offsets_alloc.get();
static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
stream);
} else {
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
sizeof(float) * 8, stream);
}
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
stream);
} else {
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
0, sizeof(float) * 8, stream);
}
}
#endif // GGML_CUDA_USE_CUB
// Bitonic sort implementation
template<typename T> template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) { static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a; T tmp = a;
@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
return n; return n;
} }
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { static void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2 // bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols); const int ncols_pad = next_power_of_2(ncols);
@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) { if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad); k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) { } else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad); k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
@ -197,6 +280,19 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (src0->type == GGML_TYPE_I32) { if (src0->type == GGML_TYPE_I32) {
argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
} else { } else {
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); #ifdef GGML_CUDA_USE_CUB
const int ncols_pad = next_power_of_2(ncols);
const size_t shared_mem = ncols_pad * sizeof(int);
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
if (shared_mem > max_shared_mem || ncols > 1024) {
ggml_cuda_pool & pool = ctx.pool();
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
} else {
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
}
#else
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
} }
} }
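The dispatch rule added above, restated as a hedged helper: the bitonic kernel keeps one int index per padded column in shared memory, so rows that are too long (or whose padded footprint exceeds the per-block shared memory) fall back to CUB's segmented radix sort.

static bool argsort_should_use_cub(int ncols, size_t smpb /* shared memory per block */) {
    const int    ncols_pad  = next_power_of_2(ncols);    // helper shown earlier in this file
    const size_t shared_mem = ncols_pad * sizeof(int);   // index storage needed by the bitonic kernel
    return shared_mem > smpb || ncols > 1024;
}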

View File

@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
if (block_nums.z > 65535) { if (block_nums.z > 65535 || block_nums.y > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
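Context for the widened check: CUDA caps gridDim.y and gridDim.z at 65535 while gridDim.x may go up to 2^31-1, so once either cap is hit the op is launched over a flat 1-D index and the precomputed fastdiv values recover the per-dimension coordinates, roughly as in this plain-division sketch.

#include <stdint.h>

static void unravel_index(int64_t i, int64_t ne0, int64_t ne1, int64_t ne2,
                          int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
    *i3 =  i / (ne0*ne1*ne2);
    *i2 = (i - *i3*ne0*ne1*ne2) / (ne0*ne1);
    *i1 = (i - *i3*ne0*ne1*ne2 - *i2*ne0*ne1) / ne0;
    *i0 =  i - *i3*ne0*ne1*ne2 - *i2*ne0*ne1 - *i1*ne0;
}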

View File

@ -982,13 +982,6 @@ struct ggml_cuda_graph {
bool disable_due_to_failed_graph_capture = false; bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0; int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties; std::vector<ggml_graph_node_properties> ggml_graph_properties;
bool use_cpy_indirection = false;
std::vector<char *> cpy_dest_ptrs;
char ** dest_ptrs_d;
int dest_ptrs_size = 0;
// Index to allow each cpy kernel to be aware of it's position within the graph
// relative to other cpy nodes.
int graph_cpynode_index = -1;
#endif #endif
}; };

View File

@ -8,18 +8,16 @@
typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
template <cpy_kernel_t cpy_1> template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne, static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int nb12, const int nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) { if (i >= ne) {
return; return;
} }
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets // then combine those indices with the corresponding byte offsets to get the total offsets
const int64_t i03 = i/(ne00 * ne01 * ne02); const int64_t i03 = i/(ne00 * ne01 * ne02);
@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int nb12, const int nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
const int i03 = i/(ne00 * ne01 * ne02); const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne, static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int nb12, const int nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
const int i03 = i/(ne00 * ne01 * ne02); const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
cpy_blck(cx + x_offset, cdst + dst_offset); cpy_blck(cx + x_offset, cdst + dst_offset);
} }
// Copy destination pointers to GPU to be available when pointer indirection is in use
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
CUDA_CHECK(cudaStreamSynchronize(stream));
if (cuda_graph->dest_ptrs_d != nullptr) {
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
}
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
}
// copy destination pointers to GPU
CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
cuda_graph->graph_cpynode_index = 0; // reset index
#else
GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
#endif
}
template<typename src_t, typename dst_t> template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda( static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q8_0_cuda( static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0); GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0; const int num_blocks = ne / QK8_0;
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q8_0_f32_cuda( static void ggml_cpy_q8_0_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q4_0_cuda( static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0); GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0; const int num_blocks = ne / QK4_0;
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q4_0_f32_cuda( static void ggml_cpy_q4_0_f32_cuda(
@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { cudaStream_t stream) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q4_1_cuda( static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0); GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1; const int num_blocks = ne / QK4_1;
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q4_1_f32_cuda( static void ggml_cpy_q4_1_f32_cuda(
@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { cudaStream_t stream) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_0_cuda( static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0); GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0; const int num_blocks = ne / QK5_0;
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q5_0_f32_cuda( static void ggml_cpy_q5_0_f32_cuda(
@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { cudaStream_t stream) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_1_cuda( static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0); GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1; const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q5_1_f32_cuda( static void ggml_cpy_q5_1_f32_cuda(
@ -259,30 +233,29 @@ static void ggml_cpy_q5_1_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { cudaStream_t stream) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_iq4_nl_cuda( static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0); GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL; const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
template <cpy_kernel_t cpy_1> template <cpy_kernel_t cpy_1>
static __global__ void cpy_i32_i32( static __global__ void cpy_i32_i32(
const char *cx, char *cdst, const int ne, const char *cx, char *cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
@ -302,23 +275,20 @@ static __global__ void cpy_i32_i32(
const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
char * cdst_ptr = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst; cpy_1(cx + x_offset, cdst + dst_offset);
cpy_1(cx + x_offset, cdst_ptr + dst_offset);
} }
static void ggml_cpy_i32_i32_cuda( static void ggml_cpy_i32_i32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream, cdst_indirect, graph_cpynode_index); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream);
} }
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0); const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1)); GGML_ASSERT(ne == ggml_nelements(src1));
@ -352,16 +322,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src0_ddc = (char *) src0->data; char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data; char * src1_ddc = (char *) src1->data;
char ** dest_ptrs_d = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection_for_this_node);
#endif
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@ -370,136 +330,65 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else } else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
{ {
if (src0->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
} }
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); // TODO consider converting to template
ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else { } else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type)); ggml_type_name(src0->type), ggml_type_name(src1->type));
} }
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection_for_this_node);
#endif
} }
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
bool disable_indirection = true; ggml_cuda_cpy(ctx, src0, dst);
ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
}
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
// Prioritize CUDA graph compatibility over direct memory copy optimization.
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
if (src0->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else {
return nullptr;
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<float, half>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<half, half>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<half, float>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
}
} }
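The dispatch above pairs each (src, dst) type combination, including the F32↔I32 pairs handled in this function, with a templated copy kernel. As a point of reference, here is a minimal, self-contained sketch of a typed cast-copy for contiguous data in the same spirit; the kernel and wrapper names are illustrative only and are not part of ggml.

// Minimal sketch of a typed cast-copy for contiguous tensors, analogous in spirit
// to the cpy_flt<...> kernels dispatched above. Names are illustrative only.
#include <cuda_runtime.h>
#include <cstdint>

template <typename src_t, typename dst_t>
__global__ void cast_copy(const src_t * src, dst_t * dst, int64_t n) {
    const int64_t i = blockIdx.x*(int64_t)blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = (dst_t) src[i]; // per-element conversion, e.g. float -> int32_t
    }
}

static void cast_copy_f32_to_i32(const float * src, int32_t * dst, int64_t n, cudaStream_t stream) {
    const int block = 256;
    const int grid  = (int) ((n + block - 1) / block);
    cast_copy<float, int32_t><<<grid, block, 0, stream>>>(src, dst, n);
}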

View File

@ -2,10 +2,6 @@
#define CUDA_CPY_BLOCK_SIZE 64 #define CUDA_CPY_BLOCK_SIZE 64
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);

View File

@ -895,6 +895,7 @@ void launch_fattn(
const dim3 block_dim(warp_size, nwarps, 1); const dim3 block_dim(warp_size, nwarps, 1);
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
GGML_ASSERT(max_blocks_per_sm > 0);
int parallel_blocks = max_blocks_per_sm; int parallel_blocks = max_blocks_per_sm;
dim3 blocks_num; dim3 blocks_num;
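The added GGML_ASSERT guards the occupancy query result before it is used as a block count. A standalone sketch of the same host-side pattern, with an illustrative kernel and shared-memory size:

// Occupancy check pattern mirrored from the hunk above; kernel and sizes are illustrative.
#include <cuda_runtime.h>
#include <cassert>
#include <cstdio>

__global__ void dummy_kernel(float * x) { x[threadIdx.x] *= 2.0f; }

int main() {
    int max_blocks_per_sm = 0;
    const int    block_size    = 256;
    const size_t nbytes_shared = 16*1024; // dynamic shared memory requested per block
    const cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_sm, dummy_kernel, block_size, nbytes_shared);
    assert(err == cudaSuccess);
    assert(max_blocks_per_sm > 0); // mirrors the GGML_ASSERT added above
    printf("active blocks per SM: %d\n", max_blocks_per_sm);
    return 0;
}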

View File

@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
const int nwarps = nthreads / WARP_SIZE; const int nwarps = nthreads / WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>; fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = false; const bool need_f16_K = type_K == GGML_TYPE_F16;
constexpr bool need_f16_V = false; const bool need_f16_V = type_V == GGML_TYPE_F16;
constexpr size_t nbytes_shared = 0; constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
} }
@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst; const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
float logit_softcap; float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

View File

@ -117,10 +117,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
} }
#define FATTN_VEC_CASE(D, type_K, type_V) \ #define FATTN_VEC_CASE(D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ { \
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \ ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
return; \ return; \
} \ } \
} \
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
FATTN_VEC_CASE( 64, type_K, type_V) \ FATTN_VEC_CASE( 64, type_K, type_V) \
@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS #endif // GGML_CUDA_FA_ALL_QUANTS
switch (K->type) { switch (K->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// If Turing tensor cores available, use them: // If Turing tensor cores available, use them:
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
if (can_use_vector_kernel) { if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
return BEST_FATTN_KERNEL_VEC; return BEST_FATTN_KERNEL_VEC;
} }
@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// If there are no tensor cores available, use the generic tile kernel: // If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) { if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (Q->ne[1] == 1) { if (Q->ne[1] == 1) {
if (!gqa_opt_applies) { if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC; return BEST_FATTN_KERNEL_VEC;
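Both selection paths in this file now accept any non-quantized K/V for the vector kernels, and the relaxed FATTN_VEC_CASE macro routes F32 K/V to the existing F16 template instantiations. A hedged sketch of the eligibility test in isolation, assuming only ggml.h for the type enum and ggml_is_quantized(); the helper name is illustrative:

// Sketch of the relaxed K/V eligibility test used when picking the vector kernel.
#include "ggml.h"

static bool fattn_vec_kv_types_ok(enum ggml_type type_K, enum ggml_type type_V) {
    // K and V may be any non-quantized type (typically F16 or F32); per the macro
    // above, an F32 tensor is dispatched through the F16 template instantiation.
    return !ggml_is_quantized(type_K) && !ggml_is_quantized(type_V);
}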

View File

@ -2774,11 +2774,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
} }
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
int batch_size, bool use_cuda_graph) { int batch_size, bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph // Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@ -2839,33 +2838,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
} }
} }
if (node->op == GGML_OP_CPY) {
// Store the pointers which are updated for each token, such that these can be sent
// to the device and accessed using indirection from CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
}
}
if (!use_cuda_graph) { if (!use_cuda_graph) {
break; break;
} }
} }
if (use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = true;
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
}
return use_cuda_graph; return use_cuda_graph;
} }
@ -2884,7 +2861,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address && if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_CPY &&
node->op != GGML_OP_VIEW) { node->op != GGML_OP_VIEW) {
return false; return false;
} }
@ -2905,7 +2881,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
for (int i = 0; i < GGML_MAX_SRC; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] && if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] && node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_CPY &&
node->op != GGML_OP_VIEW node->op != GGML_OP_VIEW
) { ) {
return false; return false;
@ -2985,18 +2960,15 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
#endif #endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes //TODO: remove special case once ggml_can_fuse can handle empty nodes
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false); std::initializer_list<enum ggml_op> topk_moe_ops =
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true); ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) { if (ops.size() == topk_moe_ops_with_norm.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
}
ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+8]; ggml_tensor * weights = cgraph->nodes[node_idx+8];
@ -3005,16 +2977,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
} }
if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) { if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_ops.size(); i++) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
}
ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+4]; ggml_tensor * weights = cgraph->nodes[node_idx+4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) { if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@ -3022,6 +2986,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
} }
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
return true;
}
}
if (!ggml_can_fuse(cgraph, node_idx, ops)) { if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false; return false;
} }
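The top-k MoE fusion checks above now delegate the bounds and op-sequence validation to ggml_can_fuse_subgraph instead of the hand-rolled loops that were removed. For reference, a sketch reconstructing what the removed loop checked, using the public graph accessors; it is not the actual ggml_can_fuse_subgraph, which also receives the indices of nodes that may be consumed outside the fused subgraph:

// Reconstruction of the removed op-sequence check, using only public ggml graph accessors.
#include "ggml.h"
#include <initializer_list>

static bool ops_match_at(struct ggml_cgraph * cgraph, int node_idx,
                         std::initializer_list<enum ggml_op> ops) {
    if (node_idx + (int) ops.size() > ggml_graph_n_nodes(cgraph)) {
        return false; // sequence would run past the end of the graph
    }
    int i = 0;
    for (const enum ggml_op op : ops) {
        if (ggml_graph_node(cgraph, node_idx + i)->op != op) {
            return false; // op at this position differs from the expected fusion pattern
        }
        ++i;
    }
    return true;
}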
@ -3052,7 +3026,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
//if rms norm is the B operand, then we don't handle broadcast //if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false; return false;
} }
@ -3121,7 +3095,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i+8]; ggml_tensor * weights = cgraph->nodes[i+8];
ggml_tensor * selected_experts = cgraph->nodes[i+3]; ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true); ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false);
i += 8; i += 8;
continue; continue;
} }
@ -3129,11 +3104,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i+4]; ggml_tensor * weights = cgraph->nodes[i+4];
ggml_tensor * selected_experts = cgraph->nodes[i+3]; ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false); ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4; i += 4;
continue; continue;
} }
if (ggml_cuda_can_fuse(cgraph, i,
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 5];
ggml_tensor * ids = cgraph->nodes[i + 1];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
/*delayed_softmax*/ true);
i += 5;
continue;
}
if (node->op == GGML_OP_ADD) { if (node->op == GGML_OP_ADD) {
int n_fuse = 0; int n_fuse = 0;
ggml_op ops[8]; ggml_op ops[8];
@ -3278,7 +3265,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
if (use_cuda_graph) { if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, batch_size, use_cuda_graph); use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) { if (use_cuda_graph && cuda_graph_update_required) {
@ -3305,10 +3292,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
} }
if (!use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = false;
}
#else #else
bool use_cuda_graph = false; bool use_cuda_graph = false;
bool cuda_graph_update_required = false; bool cuda_graph_update_required = false;
@ -3922,12 +3905,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_ACC: case GGML_OP_ACC:
return true; return true;
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width #ifndef GGML_CUDA_USE_CUB
return op->src[0]->ne[0] <= 1024; return op->src[0]->ne[0] <= 1024;
#else
return true;
#endif
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:

View File

@ -1,5 +1,7 @@
#include "ggml.h" #include "ggml.h"
#include "mmf.cuh" #include "mmf.cuh"
#include "mmid.cuh"
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( src1->type == GGML_TYPE_F32);
@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
mmf_ids_data ids_info{};
mmf_ids_data * ids_info_ptr = nullptr;
ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT: // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_dst = ids ? ne1 : ne2; const int64_t nchannels_dst = ids ? ne1 : ne2;
@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
nchannels_y = ids->ne[0]; nchannels_y = ids->ne[0];
} }
if (ids && ncols_dst > 16) {
const int64_t n_expert_used = ids->ne[0];
const int64_t n_experts = ne02;
const int64_t n_tokens = ne12;
const int64_t ne_get_rows = n_tokens * n_expert_used;
ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
const int si1 = static_cast<int>(ids_s1);
const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
GGML_ASSERT(sis1 > 0);
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
CUDA_CHECK(cudaGetLastError());
ids_info.ids_src_compact = ids_src_compact_dev.get();
ids_info.ids_dst_compact = ids_dst_compact_dev.get();
ids_info.expert_bounds_dev = expert_bounds_dev.get();
ids_info.n_experts = static_cast<int>(n_experts);
ids_info.sis1 = sis1;
ids_info_ptr = &ids_info;
}
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: { case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data; const half2 * src0_d = (const half2 *) src0->data;
@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
default: default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
} }
if (mul_mat_id) { if (mul_mat_id) {
if (type == GGML_TYPE_F32 && src1_ncols > 32) { if (src0_ne[1] <= 1024 && src1_ncols > 512) {
return false; return false;
} } else if(src0_ne[1] > 1024 && src1_ncols > 128) {
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
return false; return false;
} }
} else { } else {
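The new code path added above compacts the expert-routing ids before launching the batched kernel: ggml_cuda_launch_mm_ids_helper groups entries by expert and records per-expert offsets in expert_bounds. A CPU reference sketch of that layout under simplified assumptions (a row-major ids matrix and plain token indices; the real helper encodes flattened src1/dst column indices and runs on the GPU):

// CPU reference sketch of the compact id layout used by the batched MUL_MAT_ID path:
// entries are grouped by expert, and expert_bounds[e]..expert_bounds[e+1] brackets
// expert e's slice. Assumes a row-major ids matrix of shape [n_tokens][n_expert_used].
#include <cstdint>
#include <vector>

struct compact_ids {
    std::vector<int32_t> ids_src;       // source column per compact entry (token index here)
    std::vector<int32_t> ids_dst;       // destination column per compact entry
    std::vector<int32_t> expert_bounds; // size n_experts + 1, prefix sums of per-expert counts
};

static compact_ids build_compact_ids(const std::vector<int32_t> & ids,
                                     int n_tokens, int n_expert_used, int n_experts) {
    compact_ids out;
    out.expert_bounds.assign(n_experts + 1, 0);
    for (int e = 0; e < n_experts; ++e) {
        out.expert_bounds[e] = (int32_t) out.ids_src.size();
        for (int t = 0; t < n_tokens; ++t) {
            for (int s = 0; s < n_expert_used; ++s) {
                if (ids[t*n_expert_used + s] == e) {
                    out.ids_src.push_back(t);                   // which src1 column feeds this entry
                    out.ids_dst.push_back(t*n_expert_used + s); // where the result is scattered in dst
                }
            }
        }
    }
    out.expert_bounds[n_experts] = (int32_t) out.ids_src.size();
    return out;
}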

View File

@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32 #define MMF_ROWS_PER_BLOCK 32
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
const int32_t * ids_dst_compact = nullptr;
const int32_t * expert_bounds_dev = nullptr;
int n_experts = 0;
int sis1 = 0;
};
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
} }
// This kernel is for larger batch sizes of mul_mat_id
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
const T * __restrict__ x, const float * __restrict__ y,
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int expert_idx = blockIdx.y;
const int expert_start = expert_bounds[expert_idx];
const int expert_end = expert_bounds[expert_idx + 1];
const int ncols_expert = expert_end - expert_start;
const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
const int tile_idx = blockIdx.z;
if (tile_idx >= tiles_for_expert) {
return;
}
const int col_base = tile_idx * cols_per_block;
GGML_UNUSED(channel_ratio);
const int channel_x = expert_idx;
const int sample_dst = 0;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
y += int64_t(sample_y) *stride_sample_y;
dst += int64_t(sample_dst)*stride_sample_dst;
const int32_t * ids_src_expert = ids_src_compact + expert_start;
const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
extern __shared__ char data_mmv[];
char * compute_base = data_mmv;
//const float2 * y2 = (const float2 *) y;
tile_C C[ntA][ntB];
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
if constexpr (std::is_same_v<T, float>) {
float vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float val = 0.0f;
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
val = y[channel*stride_channel_y + token*stride_col_y + col];
}
}
vals[j0] = val;
}
};
gather_tile(0, vals_buf[0]);
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
float2 vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float2 tmp = make_float2(0.0f, 0.0f);
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
}
}
vals[j0] = tmp;
}
};
if (ntB > 0) {
gather_tile(0, vals_buf[0]);
}
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const float2 tmp = vals_buf[curr_buf][j0];
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
const int global_j = col_base + j;
if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
const int dst_entry = ids_dst_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
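Inside the kernel above, each compact entry is decoded with ggml's fastdiv helpers into the (token, channel) pair used to address y and dst, while blockIdx.y and blockIdx.z select the expert and the column tile, with an early return for tiles past that expert's column count. A plain-arithmetic sketch of the decode step, assuming the same flattened encoding; the kernel itself uses init_fastdiv_values/fast_div_modulo rather than runtime division:

// Plain-arithmetic equivalent of the fast_div_modulo decode in mul_mat_f_ids:
// a compact src entry is split by sis1 into a quotient (token) and remainder (channel),
// matching the variable names used in the kernel. Illustrative only.
#include <cstdint>
#include <utility>

static std::pair<int, int> decode_src_entry(int32_t src_entry, int sis1) {
    const int token   = src_entry / sis1; // selects the src1 column for this token
    const int channel = src_entry % sis1; // selects the src1 channel the entry came from
    return {token, channel};
}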
template<typename T, int cols_per_block, int nwarps> template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids( static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst, const T * x, const float * y, const int32_t * ids, float * dst,
@ -232,11 +484,33 @@ static inline void mul_mat_f_switch_ids(
const int64_t stride_col_id, const int64_t stride_row_id, const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
if (ids) { const mmf_ids_data * ids_data) {
const bool has_ids_data = ids_data && ids_data->ids_src_compact;
// Use the compact-ids kernel only for larger tiles; for small ncols_dst (<= 16)
// we prefer the normal mul_mat_f path with has_ids=true.
if (has_ids_data && ncols_dst > 16) {
const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
if (max_tiles == 0) {
return;
}
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
sis1_fd, nch_fd);
} else if (ids) {
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
dim3 block_nums_ids = block_nums; dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles; block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@ -258,7 +532,7 @@ void mul_mat_f_cuda(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<16, 8, T> tile_A; typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B; typedef tile< 8, 8, T> tile_B;
@ -290,7 +564,7 @@ void mul_mat_f_cuda(
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1); const dim3 block_dims(warp_size, nwarps_best, 1);
@ -300,49 +574,57 @@ void mul_mat_f_cuda(
mul_mat_f_switch_ids<T, cols_per_block, 1>( mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 2: { case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>( mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 3: { case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>( mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 4: { case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>( mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 5: { case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>( mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 6: { case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>( mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 7: { case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>( mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 8: { case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>( mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { cudaStream_t stream, const mmf_ids_data * ids_data) {
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
case 1: { case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 2: { case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 3: { case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 4: { case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 5: { case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 6: { case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 7: { case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 8: { case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 9: { case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 10: { case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 11: { case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 12: { case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 13: { case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 14: { case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 15: { case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 16: { case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream); cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \ #define DECL_MMF_CASE_EXTERN(ncols_dst) \

View File

@ -0,0 +1,164 @@
#include "common.cuh"
#include "mmid.cuh"
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mm_ids_helper_store {
uint32_t data;
__device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
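// A minimal host-side sketch of the 22/10-bit packing above (the helper name is
// hypothetical and the function is for illustration only): "it" lives in the low
// 22 bits and "iex_used" in the high 10 bits, which is why the launcher further
// down asserts n_tokens < 2^22 and n_expert_used < 2^10.
static void check_mm_ids_helper_store_roundtrip(const uint32_t it, const uint32_t iex_used) {
    GGML_ASSERT(it < (1u << 22) && iex_used < (1u << 10));
    // pack exactly like the constructor: low 22 bits = it, high 10 bits = iex_used
    const uint32_t data = (it & 0x003FFFFF) | (iex_used << 22);
    // unpack exactly like the accessors
    const uint32_t it_out       = data & 0x003FFFFF;
    const uint32_t iex_used_out = data >> 22;
    GGML_ASSERT(it_out == it && iex_used_out == iex_used);
}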
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mm_ids_helper[];
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mm_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
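// (i.e. an exclusive prefix sum of it_compact_add_self over the per-token groups, followed by a broadcast of the warp total)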
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mm_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < static_cast<int>(gridDim.x) - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
void ggml_cuda_launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
switch (n_expert_used) {
case 2:
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 4:
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 6:
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 8:
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 16:
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 32:
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
default:
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
}
}
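// For reference, a hypothetical single-threaded equivalent of the mapping built by
// mm_ids_helper above (not part of the kernels in this file; it assumes each expert
// appears at most once per token, the usual MoE top-k routing case, and ignores the
// device-side tie handling):
static void mm_ids_helper_host_ref(
        const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y,
        const int si1, const int sis1) {
    int offset = 0; // running start of the current expert's compact slice
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds[expert] = offset;
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*si1 + iex] == expert) {
                    ids_src1[offset] = it*sis1 + iex % nchannels_y;
                    ids_dst [offset] = it*n_expert_used + iex;
                    ++offset;
                    break; // at most one compact entry per (expert, token)
                }
            }
        }
    }
    expert_bounds[n_experts] = offset; // total number of compact columns
}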

View File

@ -0,0 +1,5 @@
#pragma once
void ggml_cuda_launch_mm_ids_helper(
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);

View File

@ -1,141 +1,6 @@
#include "mmq.cuh" #include "mmq.cuh"
#include "quantize.cuh" #include "quantize.cuh"
#include "mmid.cuh"
#include <vector>
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mmq_ids_helper_store {
uint32_t data;
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mmq_ids_helper[];
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mmq_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mmq_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < static_cast<int>(gridDim.x) - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) { switch (args.type_x) {
@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
const int si1 = ids->nb[1] / ggml_element_size(ids); const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11; const int sis1 = nb12 / nb11;
switch (n_expert_used) { ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
case 2:
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream); ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 4:
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 6:
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 8:
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 16:
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 32:
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
default:
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
}
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
} }

View File

@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec_f( static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int row = blockIdx.x; const int row = blockIdx.x;
const int channel_dst = blockIdx.y; const int channel_dst = blockIdx.y;
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z; const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio; const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst; const int sample_y = sample_dst;
const int tid = threadIdx.x; const int tid = threadIdx.x;
@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_dst; ++j) { for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2]; const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x*tmpy.x; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
sumf[j] += tmpx.y*tmpy.y; ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
} }
} }
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_dst; ++j) { for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2]; const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x * tmpy.x; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
sumf[j] += tmpx.y * tmpy.y; ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
} }
} }
} else { } else {
@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
#endif // FP16_AVAILABLE #endif // FP16_AVAILABLE
} }
} else if constexpr (std::is_same_v<T, nv_bfloat16>) { } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
//TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
const int * x2 = (const int *) x; const int * x2 = (const int *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) { for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2]; const int tmpx = x2[col2];
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_dst; ++j) { for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2]; const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x; const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y; const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
} }
} }
#else
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const nv_bfloat162 tmpx = x2[col2];
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
}
}
#endif
} else { } else {
static_assert(std::is_same_v<T, void>, "unsupported type"); static_assert(std::is_same_v<T, void>, "unsupported type");
} }
@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x; const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const int64_t sample_ratio = nsamples_dst / nsamples_x; const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device(); const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size; const int warp_size = ggml_cuda_info().devices[device].warp_size;
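// The channel/sample ratios above go through init_fastdiv_values so the kernel can
// replace integer division by a launch-time-constant divisor with a multiply and a
// shift. Below is a generic, self-contained sketch of that trick (the
// Granlund-Montgomery round-up form); the helper names and exact constants are
// illustrative and need not match ggml's implementation bit for bit.
#include <cassert>
#include <cstdint>

struct fastdiv_consts { uint32_t mp; uint32_t L; }; // multiplier and shift for a fixed divisor d

static fastdiv_consts init_fastdiv_consts(const uint32_t d) { // d > 0, known before launch
    uint32_t L = 0;
    while (L < 32 && (uint64_t{1} << L) < d) {
        L++;                                                  // L = ceil(log2(d))
    }
    const uint32_t mp = uint32_t((uint64_t{1} << 32)*((uint64_t{1} << L) - d)/d + 1);
    return {mp, L};
}

static uint32_t fastdiv_ref(const uint32_t n, const fastdiv_consts c) {
    const uint32_t hi = uint32_t((uint64_t{n}*c.mp) >> 32);   // __umulhi(n, mp) on the GPU
    return uint32_t((uint64_t{hi} + n) >> c.L);               // == n / d for every 32-bit n
}

int main() {
    for (uint32_t d : {1u, 3u, 6u, 160u, 4096u}) {
        const fastdiv_consts c = init_fastdiv_consts(d);
        for (uint32_t n : {0u, 1u, d, 2*d + 1, 123456789u, 0xFFFFFFFFu}) {
            assert(fastdiv_ref(n, c) == n/d);
        }
    }
    return 0;
}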
@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
case 32: { case 32: {
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 64: { case 64: {
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 96: { case 96: {
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 128: { case 128: {
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 160: { case 160: {
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 192: { case 192: {
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 224: { case 224: {
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 256: { case 256: {
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");

View File

@ -4,16 +4,61 @@
#include <initializer_list> #include <initializer_list>
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
float max_val = -INFINITY;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
max_val = max(max_val, vals[i]);
}
}
max_val = warp_reduce_max(max_val);
float sum = 0.f;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
const float val = expf(vals[i] - max_val);
vals[i] = val;
sum += val;
} else {
vals[i] = 0.f;
}
}
sum = warp_reduce_sum(sum);
const float inv_sum = 1.0f / sum;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
vals[i] *= inv_sum;
}
}
}
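// Scalar reference for what softmax_warp_inplace computes over one row of values
// (illustrative only; the device version distributes the row across warp lanes and
// replaces the loops with warp reductions, and entries at idx >= limit are only
// masked out when use_limit is set):
//   p_i = exp(v_i - max_j v_j) / sum_j exp(v_j - max_j v_j)   for i < limit
//   p_i = 0                                                   otherwise
static void softmax_inplace_ref(float * vals, const int n, const int limit) {
    float max_val = -INFINITY;
    for (int i = 0; i < limit; i++) {
        max_val = fmaxf(max_val, vals[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < limit; i++) {
        vals[i] = expf(vals[i] - max_val);
        sum += vals[i];
    }
    for (int i = 0; i < limit; i++) {
        vals[i] /= sum;
    }
    for (int i = limit; i < n; i++) {
        vals[i] = 0.0f;
    }
}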
/* /*
This kernel does the following: This kernel does the following:
1. softmax over the logits per token [n_experts, n_tokens] 1. optionally softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_experts_used) logits 2. argmax reduce over the top-k (n_experts_used) logits
3. write weights + ids to global memory 3. write weights + ids to global memory
4. optionally normalize the weights 4. optionally normalize the weights or apply softmax over the selected logits
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/ */
template <int n_experts, bool with_norm> template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights, float * weights,
int32_t * ids, int32_t * ids,
@ -30,51 +75,30 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
float logits_r[experts_per_thread]; float wt[experts_per_thread];
#pragma unroll #pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) { for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x; const int expert = i + threadIdx.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
} }
float max_val = logits_r[0]; if constexpr (!delayed_softmax) {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
} }
max_val = warp_reduce_max(max_val); //at this point, each thread holds either a portion of the softmax distribution
//or the raw logits. We do the argmax reduce over n_expert_used, each time marking
float wt[experts_per_thread];
float tmp = 0.f;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = expf(val - max_val);
tmp += wt[i];
}
tmp = warp_reduce_sum(tmp);
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
}
//at this point, each thread holds a portion of softmax,
//we do the argmax reduce over n_expert_used, each time marking
//the expert weight as -inf to exclude from the next iteration //the expert weight as -inf to exclude from the next iteration
float wt_sum = 0.f; float wt_sum = 0.f;
extern __shared__ float data_topk_shared[]; float output_weights[experts_per_thread];
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
output_weights[i] = 0.f;
}
for (int k = 0; k < n_expert_used; k++) { for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0]; float max_val = wt[0];
@ -99,10 +123,13 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
} }
} }
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
output_weights[k / WARP_SIZE] = max_val;
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY; wt[max_expert / WARP_SIZE] = -INFINITY;
wt_shared_ptr[k] = max_val;
ids[k] = max_expert; ids[k] = max_expert;
if constexpr (with_norm) { if constexpr (with_norm) {
wt_sum += max_val; wt_sum += max_val;
@ -114,17 +141,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
wt_sum = warp_reduce_sum(wt_sum); wt_sum = warp_reduce_sum(wt_sum);
const float inv_sum = 1.0f / wt_sum; const float inv_sum = 1.0f / wt_sum;
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { for (int i = 0; i < experts_per_thread; i++) {
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum; output_weights[i] *= inv_sum;
} }
} }
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { if constexpr (delayed_softmax) {
weights[i] = wt_shared_ptr[i]; softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
}
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
}
} }
} }
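// A hypothetical host-side reference for the routing that topk_moe_cuda fuses, for a
// single token row. The helper name and the use of std::partial_sort are illustrative;
// tie-breaking may differ from the warp argmax, and with_norm / delayed_softmax are
// mutually exclusive, as the launcher's static_assert below enforces.
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

static void topk_moe_ref(const std::vector<float> & logits, const int n_expert_used,
                         const bool with_norm, const bool delayed_softmax,
                         std::vector<float> & weights, std::vector<int> & ids) {
    std::vector<float> p = logits;
    if (!delayed_softmax) { // softmax over all experts up front
        const float mx = *std::max_element(p.begin(), p.end());
        float sum = 0.0f;
        for (float & x : p) { x = std::exp(x - mx); sum += x; }
        for (float & x : p) { x /= sum; }
    }
    // top-k selection by repeated argmax (here via partial_sort)
    std::vector<int> order(p.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + n_expert_used, order.end(),
                      [&p](int a, int b) { return p[a] > p[b]; });
    ids.assign(order.begin(), order.begin() + n_expert_used);
    weights.resize(n_expert_used);
    for (int k = 0; k < n_expert_used; ++k) { weights[k] = p[ids[k]]; }
    if (with_norm) { // renormalize the selected probabilities
        const float sum = std::accumulate(weights.begin(), weights.end(), 0.0f);
        for (float & w : weights) { w /= sum; }
    }
    if (delayed_softmax) { // softmax only over the selected raw logits
        const float mx = *std::max_element(weights.begin(), weights.end());
        float sum = 0.0f;
        for (float & w : weights) { w = std::exp(w - mx); sum += w; }
        for (float & w : weights) { w /= sum; }
    }
}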
template <bool with_norm> template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits, const float * logits,
float * weights, float * weights,
@ -132,53 +167,53 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const int n_rows, const int n_rows,
const int n_expert, const int n_expert,
const int n_expert_used) { const int n_expert_used) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const int rows_per_block = 4; const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1);
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
switch (n_expert) { switch (n_expert) {
case 1: case 1:
topk_moe_cuda<1, with_norm> topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 2: case 2:
topk_moe_cuda<2, with_norm> topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 4: case 4:
topk_moe_cuda<4, with_norm> topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 8: case 8:
topk_moe_cuda<8, with_norm> topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 16: case 16:
topk_moe_cuda<16, with_norm> topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 32: case 32:
topk_moe_cuda<32, with_norm> topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 64: case 64:
topk_moe_cuda<64, with_norm> topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 128: case 128:
topk_moe_cuda<128, with_norm> topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 256: case 256:
topk_moe_cuda<256, with_norm> topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 512: case 512:
topk_moe_cuda<512, with_norm> topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
default: default:
GGML_ASSERT(false && "fatal error"); GGML_ASSERT(false && "fatal error");
@ -190,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits, const ggml_tensor * logits,
ggml_tensor * weights, ggml_tensor * weights,
ggml_tensor * ids, ggml_tensor * ids,
const bool with_norm) { const bool with_norm,
const bool delayed_softmax) {
GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -198,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const int n_experts = logits->ne[0]; const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1]; const int n_rows = logits->ne[1];
const float * logits_d = (const float *) logits->src[0]->data; const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data; float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data; int32_t * ids_d = (int32_t *) ids->data;
@ -209,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
if (with_norm) { if (with_norm) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else { } else {
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
}
} }
} }
@ -242,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return true; return true;
} }
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) { std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
@ -250,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS }; GGML_OP_VIEW, GGML_OP_GET_ROWS };
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
GGML_ASSERT(!norm || !delayed_softmax);
if (delayed_softmax) {
return delayed_softmax_ops;
}
if (norm) { if (norm) {
return norm_ops; return norm_ops;
} }
return no_norm_ops; return no_norm_ops;
} }

View File

@ -6,9 +6,10 @@
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits, const ggml_tensor * logits,
ggml_tensor * weights, ggml_tensor * weights,
ggml_tensor * top_k, ggml_tensor * ids,
const bool with_norm); const bool with_norm,
const bool delayed_softmax = false);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm); std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

View File

@ -28,8 +28,10 @@ if (CXX_IS_HIPCC)
" Prefer setting the HIP compiler directly. See README for details.") " Prefer setting the HIP compiler directly. See README for details.")
endif() endif()
else() else()
# Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES. # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
endif() endif()
cmake_minimum_required(VERSION 3.21) cmake_minimum_required(VERSION 3.21)

View File

@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
const struct ggml_tensor * node = cgraph->nodes[node_idx];
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
return 0;
}
return cgraph->use_counts[hash_pos];
}
// return true if the node's results are only used by N other nodes // return true if the node's results are only used by N other nodes
// and can be fused into their calculations. // and can be fused into their calculations.
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) { static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
const struct ggml_tensor * node = cgraph->nodes[node_idx]; const struct ggml_tensor * node = cgraph->nodes[node_idx];
// check the use count against how many we're replacing // check the use count against how many we're replacing
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
return false; return false;
} }
@ -638,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
} }
GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
const int * node_idxs,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs);
// Returns true if the subgraph formed by {node_idxs} can be fused
// checks whether all nodes which are not part of outputs can be elided
// by checking if their num_uses are confined to the subgraph
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int node_idx,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs) {
GGML_ASSERT(count < 32);
if (node_idx + count > cgraph->n_nodes) {
return false;
}
int idxs[32];
for (int i = 0; i < count; ++i) {
idxs[i] = node_idx + i;
}
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
}
// Management libraries for fetching more accurate free VRAM data // Management libraries for fetching more accurate free VRAM data
GGML_API int ggml_nvml_init(); GGML_API int ggml_nvml_init();
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
@ -662,6 +701,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size()); return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
} }
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int start_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<int> outputs = {}) {
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
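// Hypothetical usage sketch (not taken from a real backend): before replacing a
// three-node SOFT_MAX -> RESHAPE -> ARGSORT chain starting at node i with one fused
// kernel, check that only the chain's last node is consumed outside the subgraph.
// This assumes the `outputs` entries are absolute node indices, matching how the
// wrapper above builds `idxs` from node_idx.
static inline bool example_can_fuse_three(const struct ggml_cgraph * cgraph, int i) {
    return ggml_can_fuse_subgraph(cgraph, i,
            { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT },
            { i + 2 });
}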
// expose GGUF internals for test code // expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

View File

@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
return res; return res;
} }
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
char base[256];
char name[256];
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_UPSCALE); assert(op->op == GGML_OP_UPSCALE);

View File

@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@ -7,6 +7,8 @@
#include <Metal/Metal.h> #include <Metal/Metal.h>
#include <stdatomic.h>
#ifndef TARGET_OS_VISION #ifndef TARGET_OS_VISION
#define TARGET_OS_VISION 0 #define TARGET_OS_VISION 0
#endif #endif
@ -22,6 +24,9 @@
// overload of MTLGPUFamilyMetal3 (not available in some environments) // overload of MTLGPUFamilyMetal3 (not available in some environments)
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
// virtual address for GPU memory allocations
static atomic_uintptr_t g_addr_device = 0x000000400ULL;
#if !GGML_METAL_EMBED_LIBRARY #if !GGML_METAL_EMBED_LIBRARY
// Here to assist with NSBundle Path Hack // Here to assist with NSBundle Path Hack
@interface GGMLMetalClass : NSObject @interface GGMLMetalClass : NSObject
@ -648,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
return true; return true;
case GGML_OP_CONV_TRANSPOSE_2D:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32; return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR: case GGML_OP_SQR:
@ -657,6 +667,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LOG: case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM: case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
@ -693,7 +704,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
// for new head sizes, add checks here // for new head sizes, add checks here
if (op->src[0]->ne[0] != 40 && if (op->src[0]->ne[0] != 32 &&
op->src[0]->ne[0] != 40 &&
op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 64 &&
op->src[0]->ne[0] != 80 && op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 && op->src[0]->ne[0] != 96 &&
@ -826,7 +838,7 @@ struct ggml_metal_buffer_wrapper {
}; };
struct ggml_metal_buffer { struct ggml_metal_buffer {
void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985 void * all_data;
size_t all_size; size_t all_size;
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@ -964,14 +976,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
if (shared) { if (shared) {
res->all_data = ggml_metal_host_malloc(size_aligned); res->all_data = ggml_metal_host_malloc(size_aligned);
res->is_shared = true; res->is_shared = true;
res->owned = true;
} else { } else {
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below // use virtual address from g_addr_device counter
res->all_data = (void *) 0x000000400ULL; res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
res->is_shared = false; res->is_shared = false;
} }
res->all_size = size_aligned; res->all_size = size_aligned;
res->owned = true;
res->device = ggml_metal_device_get_obj(dev); res->device = ggml_metal_device_get_obj(dev);
res->queue = ggml_metal_device_get_queue(dev); res->queue = ggml_metal_device_get_queue(dev);
@ -982,15 +995,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
res->buffers[0].metal = nil; res->buffers[0].metal = nil;
if (size_aligned > 0) { if (size_aligned > 0) {
if (props_dev->use_shared_buffers &&shared) { if (props_dev->use_shared_buffers && shared) {
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
length:size_aligned length:size_aligned
options:MTLResourceStorageModeShared options:MTLResourceStorageModeShared
deallocator:nil]; deallocator:nil];
} else { } else {
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
} }
} }
@ -1138,7 +1149,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
if (buf->is_shared) { if (buf->is_shared) {
memset((char *)tensor->data + offset, value, size); memset((char *) tensor->data + offset, value, size);
return; return;
} }
@ -1167,7 +1178,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
if (buf->is_shared) { if (buf->is_shared) {
memcpy((char *)tensor->data + offset, data, size); memcpy((char *) tensor->data + offset, data, size);
return; return;
} }
@ -1222,7 +1233,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
if (buf->is_shared) { if (buf->is_shared) {
memcpy(data, (const char *)tensor->data + offset, size); memcpy(data, (const char *) tensor->data + offset, size);
return; return;
} }

View File

@ -2136,6 +2136,7 @@ typedef struct {
int32_t sect_1; int32_t sect_1;
int32_t sect_2; int32_t sect_2;
int32_t sect_3; int32_t sect_3;
bool src2;
} ggml_metal_kargs_rope; } ggml_metal_kargs_rope;
typedef struct { typedef struct {
@ -2398,6 +2399,19 @@ typedef struct {
uint64_t nb1; uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d; } ggml_metal_kargs_conv_transpose_1d;
typedef struct {
int32_t IC;
int32_t IH;
int32_t IW;
int32_t KH;
int32_t KW;
int32_t OC;
int32_t s0;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;
typedef struct { typedef struct {
uint64_t ofs0; uint64_t ofs0;
uint64_t ofs1; uint64_t ofs1;
@ -4392,18 +4406,48 @@ kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args, constant ggml_metal_kargs_sum & args,
device const float * src0, device const float * src0,
device float * dst, device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) { threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (tiitg != 0) { if (args.np == 0) {
return; return;
} }
float acc = 0.0f; const uint nsg = (ntg.x + 31) / 32;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i]; float sumf = 0;
for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
} }
dst[0] = acc; sumf = simd_sum(sumf);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
} }
template <bool norm> template <bool norm>
@ -6413,7 +6457,7 @@ kernel void kernel_rope_norm(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -6466,7 +6510,7 @@ kernel void kernel_rope_neox(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -6533,7 +6577,7 @@ kernel void kernel_rope_multi(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -6600,7 +6644,7 @@ kernel void kernel_rope_vision(
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope // end of mrope
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -6810,6 +6854,97 @@ kernel void kernel_conv_transpose_1d<half>(
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]); uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
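// Reference sketch (not from the upstream diff): what a single threadgroup of the kernel above
// computes for one output element (out_x, out_y, out_c), written as plain C++ for clarity. The
// names mirror ggml_metal_kargs_conv_transpose_2d; the Metal kernel parallelises the (kw, kh)
// loops across the threadgroup and sums the partials through `shared_sum` before writing to dst.
//
// static float conv_transpose_2d_ref(const float * kern, const float * input,
//                                    const ggml_metal_kargs_conv_transpose_2d & a,
//                                    int64_t out_x, int64_t out_y, int64_t out_c) {
//     float total = 0.0f;
//     for (int64_t kh = 0; kh < a.KH; kh++) {
//         for (int64_t kw = 0; kw < a.KW; kw++) {
//             const int64_t in_y = out_y - kh;
//             const int64_t in_x = out_x - kw;
//             // output positions that no input sample maps onto under stride s0 contribute nothing
//             if (in_y < 0 || in_x < 0 || in_y % a.s0 || in_x % a.s0) continue;
//             if (in_y / a.s0 >= a.IH || in_x / a.s0 >= a.IW) continue;
//             for (int64_t in_c = 0; in_c < a.IC; in_c++) {
//                 const int64_t input_idx  = (a.IW * a.IH) * in_c + a.IW * (in_y / a.s0) + in_x / a.s0;
//                 const int64_t kernel_idx = (a.KH * a.KW * a.OC) * in_c + (a.KH * a.KW) * out_c + a.KW * kh + kw;
//                 total += kern[kernel_idx] * input[input_idx];
//             }
//         }
//     }
//     return total;
// }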
kernel void kernel_upscale_f32( kernel void kernel_upscale_f32(
constant ggml_metal_kargs_upscale & args, constant ggml_metal_kargs_upscale & args,
device const char * src0, device const char * src0,
@ -7938,8 +8073,30 @@ kernel void kernel_flash_attn_ext(
half, half4, simdgroup_half8x8 half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8 //float, float4, simdgroup_float8x8
#define FA_TYPES_F32 \
half, half4, simdgroup_half8x8, \
float, float4x4, simdgroup_float8x8, \
float, float4x4, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float2, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8
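// NOTE (illustrative, not part of the upstream diff): FA_TYPES_F32 appears to mirror FA_TYPES with
// the half-typed K/V and output slots widened to float, and the kernel_flash_attn_ext_f32_dk*_dv*
// instantiations below pair it with float4x4 blocks and dequantize_f32 -- i.e. flash-attention
// kernels that read an F32 (unquantized, non-half) KV cache for every supported head-size pair.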
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t; typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
@ -7952,6 +8109,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
@ -7964,6 +8122,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif #endif
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
@ -7975,6 +8134,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
@ -7986,6 +8146,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
@ -7997,6 +8158,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
@ -8008,6 +8170,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
@ -8543,8 +8706,28 @@ kernel void kernel_flash_attn_ext_vec(
float, float4, \ float, float4, \
float4 float4
#define FA_TYPES_F32 \
half4, \
float4, \
float4, \
float, \
float, float4, \
float4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
@ -8555,6 +8738,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
@ -8565,6 +8749,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
@ -8575,6 +8760,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
@ -8585,6 +8771,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
@ -8595,6 +8782,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
@ -8605,6 +8793,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;

View File

@ -251,6 +251,7 @@ typedef struct {
int32_t sect_1; int32_t sect_1;
int32_t sect_2; int32_t sect_2;
int32_t sect_3; int32_t sect_3;
bool src2;
} ggml_metal_kargs_rope; } ggml_metal_kargs_rope;
typedef struct { typedef struct {
@ -513,6 +514,19 @@ typedef struct {
uint64_t nb1; uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d; } ggml_metal_kargs_conv_transpose_1d;
typedef struct {
int32_t IC;
int32_t IH;
int32_t IW;
int32_t KH;
int32_t KW;
int32_t OC;
int32_t s0;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;
typedef struct { typedef struct {
uint64_t ofs0; uint64_t ofs0;
uint64_t ofs1; uint64_t ofs1;

View File

@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{ {
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
} break; } break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
} break;
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
{ {
n_fuse = ggml_metal_op_upscale(ctx, idx); n_fuse = ggml_metal_op_upscale(ctx, idx);
@ -866,12 +870,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
int nth = 32; // SIMD width
while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, (int) n);
const int nsg = (nth + 31) / 32;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
return 1; return 1;
} }
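// Worked example (illustrative, not part of the upstream diff): for n = 1000 elements and an
// assumed pipeline limit of 1024 threads per threadgroup, nth doubles 32 -> 64 -> ... -> 1024 and
// is then clamped to min(1024, 1000) = 1000, so nsg = (1000 + 31) / 32 = 32 simdgroups,
// 32 * sizeof(float) = 128 bytes of threadgroup memory are reserved for the partial sums, and the
// kernel runs as a single threadgroup of nth threads.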
@ -2969,6 +2986,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
/* sect_1 =*/ sect_1, /* sect_1 =*/ sect_1,
/* sect_2 =*/ sect_2, /* sect_2 =*/ sect_2,
/* sect_3 =*/ sect_3, /* sect_3 =*/ sect_3,
/* src2 =*/ op->src[2] != nullptr,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
@ -3104,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
return 1; return 1;
} }
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
const int32_t IC = op->src[1]->ne[2];
const int32_t IH = op->src[1]->ne[1];
const int32_t IW = op->src[1]->ne[0];
const int32_t KH = op->src[0]->ne[1];
const int32_t KW = op->src[0]->ne[0];
const int32_t OW = op->ne[0];
const int32_t OH = op->ne[1];
const int32_t OC = op->ne[2];
ggml_metal_kargs_conv_transpose_2d args = {
/*.IC =*/ IC,
/*.IH =*/ IH,
/*.IW =*/ IW,
/*.KH =*/ KH,
/*.KW =*/ KW,
/*.OC =*/ OC,
/*.s0 =*/ s0,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
// Metal requires buffer size to be multiple of 16 bytes
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
return 1;
}
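// Dispatch sketch (illustrative, not part of the upstream diff): conv_transpose_2d launches one
// threadgroup per output element -- an (OW, OH, OC) grid with (KW, KH, 1) threads each -- so every
// thread covers one kernel tap and the KW*KH partial products are reduced through the kernel's
// `shared_sum` threadgroup buffer. For a 3x3 kernel the raw requirement is 3*3*sizeof(float) = 36
// bytes, which GGML_PAD(..., 16) rounds up to 48 bytes to satisfy the 16-byte multiple noted above.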
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx); ggml_tensor * op = ctx->node(idx);

View File

@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);

View File

@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args, constant ggml_metal_kargs_sum & args,
device const float * src0, device const float * src0,
device float * dst, device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) { threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (tiitg != 0) { if (args.np == 0) {
return; return;
} }
float acc = 0.0f; const uint nsg = (ntg.x + 31) / 32;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i]; float sumf = 0;
for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
} }
dst[0] = acc; sumf = simd_sum(sumf);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
} }
template <bool norm> template <bool norm>
@ -3748,7 +3778,7 @@ kernel void kernel_rope_norm(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -3801,7 +3831,7 @@ kernel void kernel_rope_neox(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -3868,7 +3898,7 @@ kernel void kernel_rope_multi(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -3935,7 +3965,7 @@ kernel void kernel_rope_vision(
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope // end of mrope
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@ -4145,6 +4175,97 @@ kernel void kernel_conv_transpose_1d<half>(
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]); uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
kernel void kernel_upscale_f32( kernel void kernel_upscale_f32(
constant ggml_metal_kargs_upscale & args, constant ggml_metal_kargs_upscale & args,
device const char * src0, device const char * src0,
@ -5273,8 +5394,30 @@ kernel void kernel_flash_attn_ext(
half, half4, simdgroup_half8x8 half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8 //float, float4, simdgroup_float8x8
#define FA_TYPES_F32 \
half, half4, simdgroup_half8x8, \
float, float4x4, simdgroup_float8x8, \
float, float4x4, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float2, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t; typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
@ -5287,6 +5430,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
@ -5299,6 +5443,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif #endif
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
@ -5310,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
@ -5321,6 +5467,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
@ -5332,6 +5479,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
@ -5343,6 +5491,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
@ -5878,8 +6027,28 @@ kernel void kernel_flash_attn_ext_vec(
float, float4, \
float4
#define FA_TYPES_F32 \
half4, \
float4, \
float4, \
float, \
float, float4, \
float4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
@ -5890,6 +6059,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
@ -5900,6 +6070,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
@ -5910,6 +6081,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
@ -5920,6 +6092,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
@ -5930,6 +6103,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
@ -5940,6 +6114,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;

View File

@ -1,9 +1,18 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW)
if (POLICY CMP0147)
# Parallel build custom build steps
cmake_policy(SET CMP0147 NEW)
endif()
find_package(Vulkan COMPONENTS glslc REQUIRED)
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
# Parallel build object files
add_definitions(/MP)
endif()
function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)

View File

@ -97,8 +97,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define GGML_VK_MAX_NODES 8192
#define MAX_VK_BUFFERS 256
#define VK_CHECK(err, msg) \
do { \
vk::Result err_ = (err); \
@ -387,6 +385,14 @@ enum shader_reduction_mode {
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
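These two op arrays spell out the exact node sequences the backend later tries to fuse into one topk_moe dispatch; ggml_vk_can_fuse_topk_moe further below checks the op sequence, the data flow between the nodes, and a few shape/device constraints. As a rough standalone illustration of just the op-matching part — the Op enum and Node struct here are stand-ins for ggml_op and ggml_tensor, not the real types — a window match over a graph can be sketched like this:

// Minimal sketch (not backend code): match a window of graph nodes against a fixed op pattern.
#include <array>
#include <cstdio>
#include <vector>

enum class Op { SOFT_MAX, RESHAPE, ARGSORT, VIEW, GET_ROWS, SUM_ROWS, DIV, MUL };
struct Node { Op op; };

static constexpr std::array<Op, 9> topk_moe_norm_pattern{
    Op::SOFT_MAX, Op::RESHAPE, Op::ARGSORT, Op::VIEW, Op::GET_ROWS,
    Op::RESHAPE, Op::SUM_ROWS, Op::DIV, Op::RESHAPE };

// true when the ops starting at node_idx match the pattern exactly
static bool matches_pattern(const std::vector<Node> & nodes, size_t node_idx) {
    if (node_idx + topk_moe_norm_pattern.size() > nodes.size()) {
        return false;
    }
    for (size_t i = 0; i < topk_moe_norm_pattern.size(); ++i) {
        if (nodes[node_idx + i].op != topk_moe_norm_pattern[i]) {
            return false;
        }
    }
    return true;
}

int main() {
    std::vector<Node> graph = {
        {Op::SOFT_MAX}, {Op::RESHAPE}, {Op::ARGSORT}, {Op::VIEW}, {Op::GET_ROWS},
        {Op::RESHAPE}, {Op::SUM_ROWS}, {Op::DIV}, {Op::RESHAPE}, {Op::MUL} };
    std::printf("match at 0: %d\n", matches_pattern(graph, 0) ? 1 : 0); // 1
    std::printf("match at 1: %d\n", matches_pattern(graph, 1) ? 1 : 0); // 0
    return 0;
}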
struct vk_device_struct {
std::recursive_mutex mutex;
@ -584,6 +590,9 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@ -597,6 +606,9 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
// [2] is {!norm, norm}
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
std::vector<vk_pipeline_ref> all_pipelines;
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
@ -940,6 +952,11 @@ struct vk_op_multi_add_push_constants {
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
};
struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
@ -1089,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
uint32_t nb42, nb43, nb52, nb53;
uint32_t s_off;
uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
uint32_t nb01, nb02;
uint32_t nb11;
uint32_t dst_nb0, dst_nb1, dst_nb2;
uint32_t nc, ncs, nr, n_t, n_s;
};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
@ -1281,7 +1311,6 @@ struct ggml_vk_garbage_collector {
std::vector<vk_semaphore> tl_semaphores;
std::vector<vk_semaphore> semaphores;
std::vector<vk::Event> events;
std::vector<vk_buffer> temp_buffers;
std::vector<vk_context> contexts;
};
@ -1452,8 +1481,6 @@ struct ggml_backend_vk_context {
// and set to true after the buffer contents are consumed.
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
vk_buffer buffer_pool[MAX_VK_BUFFERS];
vk_context_ref compute_ctx;
vk_context_ref transfer_ctx;
@ -2651,11 +2678,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \
}
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
@ -2663,6 +2692,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
@ -3590,6 +3620,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@ -3700,6 +3740,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
}
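Each of these pipelines is specialized, via the second spec constant (1u<<i), for a maximum expert count, and the dispatch side later picks the pipeline with idx = ceil(log2(n_expert)); the fusion check further below additionally requires n_expert to be a power of two no larger than 1 << (num_topk_moe_pipelines - 1). A tiny standalone sketch of just that index arithmetic (expert counts invented):

#include <cmath>
#include <cstdint>
#include <cstdio>

// mirrors: idx = (uint32_t)ceilf(log2f(float(dst->ne[0])))
static uint32_t topk_moe_pipeline_index(uint32_t n_expert) {
    return (uint32_t)ceilf(log2f((float)n_expert));
}

int main() {
    // only power-of-two expert counts reach this path (see ggml_vk_can_fuse_topk_moe)
    for (uint32_t n_expert : {1u, 2u, 8u, 64u, 128u, 512u}) {
        const uint32_t idx = topk_moe_pipeline_index(n_expert);
        std::printf("n_expert=%3u -> pipeline_topk_moe[%u] (capacity %u)\n",
                    n_expert, idx, 1u << idx);
    }
    return 0;
}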
for (auto &c : compiles) {
c.wait();
}
@ -4690,7 +4735,14 @@ static void ggml_vk_instance_init() {
vk::PhysicalDeviceIDProperties old_id;
old_props.pNext = &old_id;
devices[k].getProperties2(&old_props);
return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
equals = equals || (
old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
);
return equals;
}
);
if (old_device == vk_instance.device_indices.end()) {
@ -4728,6 +4780,7 @@ static void ggml_vk_instance_init() {
#endif
break;
}
driver_priorities[vk::DriverId::eMesaDozen] = 100;
if (driver_priorities.count(old_driver.driverID)) {
old_priority = driver_priorities[old_driver.driverID];
@ -5101,71 +5154,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
}
static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
VK_LOG_MEMORY("ggml_vk_pool_malloc");
int best_i = -1;
size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
int worst_i = -1;
size_t worst_size = 0; //largest unused buffer seen so far
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer &b = ctx->buffer_pool[i];
if (b != nullptr && b->size >= size && b->size < best_size) {
best_i = i;
best_size = b->size;
}
if (b != nullptr && b->size > worst_size) {
worst_i = i;
worst_size = b->size;
}
}
if(best_i != -1) {
//found the smallest buffer that fits our needs
vk_buffer b = ctx->buffer_pool[best_i];
ctx->buffer_pool[best_i].reset();
return b;
}
if(worst_i != -1) {
//no buffer that fits our needs, resize largest one to save memory
vk_buffer& b = ctx->buffer_pool[worst_i];
ggml_vk_destroy_buffer(b);
}
return ggml_vk_create_buffer_device(ctx->device, size);
}
static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer& b = ctx->buffer_pool[i];
if (b == nullptr) {
b = buffer;
return;
}
}
std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
ggml_vk_destroy_buffer(buffer);
}
// Returns an available temporary buffer that may only be used temporarily, it will be reused
static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
// Try to find existing temp buffer with enough capacity
for (auto& buffer : ctx->gc.temp_buffers) {
if (buffer->size >= size) {
return buffer;
}
}
VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
// Otherwise create new buffer
vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
ctx->gc.temp_buffers.push_back(buf);
return buf;
}
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
vk_buffer buf = ggml_vk_create_buffer(device, size,
@ -7459,8 +7447,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
if (k->type == GGML_TYPE_F32) {
k_stride /= 4;
}
if (v->type == GGML_TYPE_F32) {
v_stride /= 4;
}
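The division above exists because, for F32 K/V tensors, the flash-attention shader addresses rows in float4 (vec4) units rather than scalar elements. A minimal standalone sketch of the resulting stride arithmetic, with an invented head size and contiguous rows:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t head_size = 128;                        // invented
    const size_t   type_size = sizeof(float);              // GGML_TYPE_F32
    const size_t   nbk1      = head_size * type_size;      // contiguous row: 512 bytes

    uint32_t k_stride = (uint32_t)(nbk1 / type_size);      // 128 floats per row
    k_stride /= 4;                                         // 32 float4 loads per row
    std::printf("k_stride = %u float4 elements\n", k_stride);
    return 0;
}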
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
bool aligned = (KV % alignment) == 0 &&
@ -7974,6 +7970,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
return ctx->device->pipeline_topk_moe[idx][with_norm];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
}
@ -8089,6 +8092,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
if (d_state == 128) {
return ctx->device->pipeline_ssm_scan_f32_d128;
} else if (d_state == 256) {
return ctx->device->pipeline_ssm_scan_f32_d256;
}
}
return nullptr;
case GGML_OP_SSM_CONV:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ssm_conv_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
@ -8583,6 +8601,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
break;
case GGML_OP_SSM_CONV:
{
const uint32_t nr = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
elements = { nr, n_t, n_s };
}
break;
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
@ -9029,6 +9055,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = dst->src[3];
const ggml_tensor * src4 = dst->src[4];
const ggml_tensor * src5 = dst->src[5];
GGML_ASSERT(dst->buffer != nullptr);
const uint32_t head_dim = src0->ne[1];
const uint32_t n_head = src1->ne[1];
const uint32_t n_group = src4->ne[1];
const uint32_t n_tok = src1->ne[2];
const uint32_t n_seq = src1->ne[3];
bool is_mamba2 = (src3->nb[1] == sizeof(float));
GGML_ASSERT(is_mamba2);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
const vk_op_ssm_scan_push_constants pc = {
(uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
(uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
(uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
(uint32_t)src3->nb[1],
(uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
(uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
(uint32_t)s_off,
n_head, head_dim, n_group, n_tok
};
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
if (ctx->device->uma) {
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
}
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}
size_t dst_size = ggml_nbytes(dst);
size_t src_sizes[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
}
std::array<uint32_t, 3> elements;
const int splitH = 16;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
}
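A rough standalone sketch of the dispatch and offset math used in ggml_vk_ssm_scan: groups of splitH = 16 head elements per workgroup in x, one sequence per workgroup row in y, and s_off marking where the recurrent-state portion of dst begins (after an x-shaped block, i.e. ggml_nelements(src1) * sizeof(float)). The concrete sizes below are invented for illustration and assume src1 is logically {head_dim, n_head, n_tok, n_seq}.

#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))  // same formula as the backend macro

int main() {
    const uint32_t n_head = 64, head_dim = 64, n_tok = 32, n_seq = 2;  // invented

    const uint32_t splitH = 16;
    const uint32_t wg_x   = CEIL_DIV(n_head * head_dim, splitH);       // 256
    const uint32_t wg_y   = n_seq;                                     // 2

    const uint64_t s_off = (uint64_t)head_dim * n_head * n_tok * n_seq * sizeof(float);

    std::printf("elements = {%u, %u, 1}, s_off = %llu bytes\n",
                wg_x, wg_y, (unsigned long long)s_off);
    return 0;
}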
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
(uint32_t)src1->ne[0],
(uint32_t)src0->ne[0],
(uint32_t)src0->ne[1],
(uint32_t)dst->ne[1],
(uint32_t)dst->ne[2],
}, dryrun);
}
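A minimal sketch of how the ssm_conv push constants above would be derived on the host, assuming contiguous f32 tensors and invented mamba-style shapes (src0 = conv-state window {ncs, nr, n_s}, src1 = kernel {nc, nr}, dst = {nr, n_t, n_s}); only the arithmetic is meant to mirror the call above.

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t nc = 4, n_t = 32, nr = 1536, n_s = 1;  // invented
    const uint32_t ncs = nc - 1 + n_t;                    // 35: window length in src0

    const uint32_t nb01    = ncs * sizeof(float);         // src0->nb[1]
    const uint32_t nb02    = nb01 * nr;                   // src0->nb[2]
    const uint32_t nb11    = nc * sizeof(float);          // src1->nb[1]
    const uint32_t dst_nb0 = sizeof(float);
    const uint32_t dst_nb1 = nr * sizeof(float);          // assumes dst->ne[0] == nr
    const uint32_t dst_nb2 = dst_nb1 * n_t;

    std::printf("nb01=%u nb02=%u nb11=%u dst_nb0=%u dst_nb1=%u dst_nb2=%u "
                "nc=%u ncs=%u nr=%u n_t=%u n_s=%u\n",
                nb01, nb02, nb11, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
    return 0;
}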
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * x = dst->src[0];
const ggml_tensor * g = dst->src[1];
@ -9425,6 +9562,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];
const int n_expert_used = weights->ne[1];
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
vk_buffer d_logits = nullptr;
size_t logits_buf_offset = 0;
vk_buffer d_weights = nullptr;
size_t weights_buf_offset = 0;
vk_buffer d_ids = nullptr;
size_t ids_buf_offset = 0;
bool logits_uma = false;
bool weights_uma = false;
bool ids_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
logits_uma = d_logits != nullptr;
weights_uma = d_weights != nullptr;
ids_uma = d_ids != nullptr;
}
if (!logits_uma) {
d_logits = logits_buf_ctx->dev_buffer;
logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
GGML_ASSERT(d_logits != nullptr);
}
if (!weights_uma) {
d_weights = weights_buf_ctx->dev_buffer;
weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
GGML_ASSERT(d_weights != nullptr);
}
if (!ids_uma) {
d_ids = ids_buf_ctx->dev_buffer;
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
GGML_ASSERT(d_ids != nullptr);
}
vk_op_topk_moe_push_constants pc;
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
GGML_ASSERT(n_expert_used <= n_experts);
const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
}, pc, elements);
}
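The dispatch shape above is deliberately small: one workgroup covers rows_per_block = 4 rows of logits, so only the x dimension scales with the number of token rows. A trivial standalone sketch of that grid computation:

#include <array>
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))  // same formula as the backend macro

int main() {
    const uint32_t rows_per_block = 4;  // fixed in ggml_vk_topk_moe
    for (uint32_t n_rows : {1u, 4u, 17u, 512u}) {
        const std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1u, 1u };
        std::printf("n_rows=%3u -> elements = {%u, %u, %u}\n",
                    n_rows, elements[0], elements[1], elements[2]);
    }
    return 0;
}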
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@ -10861,6 +11079,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
@ -11008,11 +11228,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
}
// Add the last fused node and all fused source nodes to the unsynchronized list.
// Add all fused nodes to the unsynchronized lists.
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
ctx->unsynced_nodes_written.push_back(last_node);
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
// Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
ctx->unsynced_nodes_written.push_back(cur_node);
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
@ -11179,7 +11399,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SOFT_MAX:
if (ctx->num_additional_fused_ops) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
} else {
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
}
break;
case GGML_OP_SOFT_MAX_BACK:
@ -11278,6 +11502,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_SSM_CONV:
ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
@ -11389,6 +11623,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
@ -11498,10 +11734,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
// Clean up after graph processing is done
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
for (auto& buffer : ctx->gc.temp_buffers) {
ggml_vk_pool_free(ctx, buffer);
}
ctx->gc.temp_buffers.clear();
ctx->prealloc_y_last_pipeline_used = {};
ctx->unsynced_nodes_written.clear();
@ -11544,10 +11776,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
ctx->prealloc_y_last_pipeline_used = nullptr;
for (auto& buffer : ctx->buffer_pool) {
ggml_vk_destroy_buffer(buffer);
}
ctx->prealloc_size_x = 0;
ctx->prealloc_size_y = 0;
ctx->prealloc_size_split_k = 0;
@ -11988,6 +12216,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
return true;
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
int node_idx, bool with_norm) {
if (with_norm) {
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
return false;
}
}
} else {
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
return false;
}
}
}
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
float max_bias = op_params[1];
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
return false;
}
// Check that the nodes don't have any unexpected uses
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
// softmax is used by reshape and argsort
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
reshape1->src[0] != softmax ||
argsort->src[0] != softmax) {
return false;
}
// reshape is used by get_rows
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
get_rows->src[0] != reshape1) {
return false;
}
// argsort is used by view
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
view->src[0] != argsort) {
return false;
}
// view is written (via argsort), we can skip checking it
if (with_norm) {
// get_rows is used by reshape
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
reshape5->src[0] != get_rows) {
return false;
}
// reshape is used by sum_rows and div
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
sum_rows->src[0] != reshape5 ||
div->src[0] != reshape5) {
return false;
}
// sum_rows is used by div
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
div->src[1] != sum_rows) {
return false;
}
// div/reshape are written
if (reshape8->src[0] != div) {
return false;
}
}
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
ctx->device->disable_fusion) {
return false;
}
return true;
}
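// Sketch of the fused sequences matched above (names and exact container assumed;
// the real topk_moe / topk_moe_norm lists are defined elsewhere in the backend):
//   topk_moe      ~ { SOFT_MAX, RESHAPE, ARGSORT, VIEW, GET_ROWS }
//   topk_moe_norm ~ { SOFT_MAX, RESHAPE, ARGSORT, VIEW, GET_ROWS,
//                     RESHAPE, SUM_ROWS, DIV, RESHAPE }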
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
const ggml_tensor *first_node = cgraph->nodes[node_idx]; const ggml_tensor *first_node = cgraph->nodes[node_idx];
@ -12063,6 +12405,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1; ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1; ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
} }
} }
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@ -12160,6 +12506,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1; ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1; ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
} }
} }
@ -12167,10 +12517,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = (submitted_nodes >= nodes_per_submit) || bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) || (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i + ctx->num_additional_fused_ops == last_node) || (i + ctx->num_additional_fused_ops >= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending); (almost_ready && !ctx->almost_ready_fence_pending);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit); bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
if (vk_perf_logger_enabled) { if (vk_perf_logger_enabled) {
if (ctx->compute_ctx.expired()) { if (ctx->compute_ctx.expired()) {
@ -12292,6 +12642,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) { while (first_unused < graph->n_nodes) {
std::vector<int> current_set; std::vector<int> current_set;
// Avoid reordering topk_moe_norm
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
bool is_topk_moe_norm = true;
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
is_topk_moe_norm = false;
}
}
if (is_topk_moe_norm) {
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
continue;
}
}
// First, grab the next unused node. // First, grab the next unused node.
current_set.push_back(first_unused); current_set.push_back(first_unused);
@ -12797,6 +13166,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
} }
switch (op->src[1]->type) { switch (op->src[1]->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
// supported in scalar and coopmat2 paths // supported in scalar and coopmat2 paths
@ -13004,6 +13374,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7: case GGML_OP_RWKV_WKV7:
return true; return true;
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
return false;
}
}
if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
return false;
}
const uint32_t d_state = op->src[0]->ne[0];
const uint32_t head_dim = op->src[0]->ne[1];
bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
if (!is_mamba2) {
return false;
}
if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
const uint32_t SPLIT_H = 16;
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
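// For reference: with SPLIT_H = 16, d_state = 128 needs 16 * 128 * 4 = 8192 bytes of
// shared memory and d_state = 256 needs 16384 bytes, which matches the 16 KiB minimum
// the Vulkan spec requires for maxComputeSharedMemorySize, so the check below mainly
// rejects devices that report a smaller limit.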
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
return true;
}
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
@ -13386,14 +13797,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
struct ggml_context * ggml_ctx = ggml_init(iparams); struct ggml_context * ggml_ctx = ggml_init(iparams);
std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0}; std::array<size_t, GGML_MAX_SRC> src_size = {};
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; std::array<void *, GGML_MAX_SRC> src_buffer = {};
const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
struct ggml_tensor * tensor_clone = nullptr; struct ggml_tensor * tensor_clone = nullptr;
for (int i = 0; i < 6; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
ggml_tensor * srci = tensor->src[i]; ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) { if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
@ -13700,6 +14111,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2]); src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) { } else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
} else if (tensor->op == GGML_OP_SSM_SCAN) {
tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_SSM_CONV) {
tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
} }
else { else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@ -13721,7 +14137,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
memcpy(comp_result, tensor_clone->data, comp_size); memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
for (int i = 0; i < 6; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
if (src_buffer[i] != nullptr) { if (src_buffer[i] != nullptr) {
free(src_buffer[i]); free(src_buffer[i]);
} }


@ -1,6 +1,18 @@
#include "types.glsl" #include "types.glsl"
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
vec4 block;
};
float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const vec4 v = bl.block;
const uint idx = coordInBlock[1];
const f16vec4 vf16 = f16vec4(v);
return vf16[idx];
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block; block_q4_0_packed16 block;
}; };
@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL #define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4) #elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4 #define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif #endif


@ -345,7 +345,7 @@ void main() {
float Lfrcp[Br]; float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
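// Lf[r] is the softmax denominator for row r; it can be 0 when every position in the
// row is masked out, so emit 0 instead of dividing and propagating inf/NaN.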
} }
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {


@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0 #define BINDING_IDX_K 0
#define BINDING_IDX_V 1 #define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif #endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return k_packed.k_data_packed[a_offset + ib];
} else {
return v_packed.v_data_packed[a_offset + ib];
}
}
#endif
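// For F32 K/V a "block" is a single vec4: BLOCK_SIZE is 4 elements, BLOCK_BYTE_SIZE is
// 16 bytes, and dequantize4 reduces to a plain vec4 load from the K or V binding.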
#if defined(DATA_A_Q4_0) #if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18 #define BLOCK_BYTE_SIZE 18


@ -380,7 +380,7 @@ void main() {
float Lfrcp[rows_per_thread]; float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
} }
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {


@ -121,7 +121,11 @@ void main() {
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2); M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
@ -294,7 +298,7 @@ void main() {
[[unroll]] [[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) { for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
} }
O = Ldiag*O; O = Ldiag*O;


@ -91,7 +91,7 @@ void main() {
L = L*ms + vs; L = L*ms + vs;
} }
L = 1.0 / L; L = (L == 0.0) ? 0.0 : 1.0 / L;
// D dimension is split across workgroups in the y dimension // D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;


@ -313,12 +313,12 @@ void main() {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f); sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
} }
#else #else
ACC_TYPE sums[WMITER * TM * WNITER * TN]; ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
FLOAT_TYPE_VEC2 cache_b[TN]; FLOAT_TYPE_VEC2 cache_b;
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = ACC_TYPE(0.0f); sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
} }
#endif #endif
@ -360,20 +360,22 @@ void main() {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
} }
} }
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) {
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
} }
} }
} }
} }
} }
#endif #endif
@ -388,8 +390,9 @@ void main() {
} }
} }
#else #else
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
} }
#endif #endif
#endif #endif
@ -463,14 +466,21 @@ void main() {
const u16vec2 row_idx = row_ids[row_i - ic * BN]; const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID #endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
if (dr_warp + cr < p.M) { if (dr_warp + 2 * cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
} }
#else #else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) { if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
} }
#endif // MUL_MAT_ID #endif // MUL_MAT_ID
} }


@ -0,0 +1,44 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };
layout(push_constant) uniform PushConstants {
uint nb01; uint nb02;
uint nb11;
uint dst_nb0; uint dst_nb1; uint dst_nb2;
uint nc; uint ncs; uint nr; uint n_t; uint n_s;
};
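// Each invocation computes one output element: thread i1 (row/channel), at timestep i2
// of sequence i3, takes a length-nc dot product between the src0 window starting at
// column i2 and row i1 of the src1 weights. The push-constant strides are in bytes,
// hence the divisions by 4 when indexing the float buffers.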
void main() {
const uint global_thread_id = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i3 = gl_WorkGroupID.z;
if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
return;
}
const uint i1 = global_thread_id;
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
const uint src1_base = i1 * (nb11 / 4);
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
float sum = 0.0;
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
const uint src0_idx = src0_base + i0;
const uint src1_idx = src1_base + i0;
sum += src0[src0_idx] * src1[src1_idx];
}
dst[dst_idx] = sum;
}


@ -0,0 +1,140 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
#include "types.glsl"
layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float s0[]; };
layout(binding = 1) readonly buffer Src1 { float x[]; };
layout(binding = 2) readonly buffer Src2 { float dt[]; };
layout(binding = 3) readonly buffer Src3 { float A[]; };
layout(binding = 4) readonly buffer Src4 { float B[]; };
layout(binding = 5) readonly buffer Src5 { float C[]; };
layout(binding = 6) readonly buffer Src6 { int ids[]; };
layout(binding = 7) buffer Dst { float d[]; };
layout(push_constant) uniform PushConstants {
uint nb02; uint nb03; uint nb12; uint nb13;
uint nb21; uint nb22; uint nb31;
uint nb42; uint nb43; uint nb52; uint nb53;
uint s_off;
uint n_head;
uint d_head;
uint n_group;
uint n_tok;
};
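// softplus(x) = log(1 + exp(x)); for x > 20 the correction term log(1 + exp(-x)) is far
// below float precision, so returning x directly is accurate and avoids evaluating
// exp() on large inputs.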
float softplus(float x) {
if (x <= 20.0) {
return log(1.0 + exp(x));
} else {
return x;
}
}
shared float stateC[SPLIT_H * D_STATE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
const uint seq_idx = gl_WorkGroupID.y;
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
const uint A_base_idx = (head_idx * nb31) / 4;
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint stride_x = nb12 / 4;
const uint stride_dt = nb21 / 4;
const uint stride_B = nb42 / 4;
const uint stride_C = nb52 / 4;
const uint stride_y = n_head * d_head;
float state[SPLIT_H];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
state[j] = s0[s0_base_idx + j * D_STATE + tid];
}
for (uint i = 0; i < n_tok; i++) {
const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
const float dA = exp(dt_soft_plus * A[A_base_idx]);
const float B_val = B[B_base_idx + i * stride_B + tid];
const float C_val = C[C_base_idx + i * stride_C + tid];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
state[j] = (state[j] * dA) + (B_val * x_dt);
stateC[j * D_STATE + tid] = state[j] * C_val;
}
barrier();
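// stateC now holds, for each of the SPLIT_H rows, d_state partial products state * C.
// The loop below tree-reduces across the d_state dimension down to subgroup width;
// the remaining per-subgroup reduction then produces one y value per row for this token.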
[[unroll]]
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + w];
}
}
barrier();
}
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
const uint max_idx = SUBGROUP_SIZE - 1 +
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
if (idx < SPLIT_H * D_STATE ||
max_idx < SPLIT_H * D_STATE) {
float sc;
#if USE_SUBGROUP_ADD
sc = stateC[idx];
sc = subgroupAdd(sc);
#else
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
if (tid % SUBGROUP_SIZE == 0) {
sc = stateC[idx];
}
#endif
if (tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = sc;
}
}
}
barrier();
}
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
d[s_base_idx + j * D_STATE + tid] = state[j];
}
}


@ -0,0 +1,139 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
return;
}
const uint logits_offset = n_experts * row;
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
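// Each row is handled by WARP_SIZE lanes of one subgroup: expert e is owned by lane
// e % WARP_SIZE and stored at local slot e / WARP_SIZE, so logits_r/wt hold a
// WARP_SIZE-strided slice of the row's logits.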
float logits_r[experts_per_thread];
const float INFINITY = 1.0 / 0.0;
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
}
float max_val = logits_r[0];
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
}
max_val = subgroupMax(max_val);
float wt[experts_per_thread];
float tmp = 0.f;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = exp(val - max_val);
tmp += wt[i];
}
tmp = subgroupAdd(tmp);
const float inv_sum = 1.0f / tmp;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
}
// at this point, each thread holds a portion of softmax,
// we do the argmax reduce over n_expert_used, each time marking
// the expert weight as -inf to exclude from the next iteration
float wt_sum = 0.f;
float output_weights[experts_per_thread];
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
[[unroll]]
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = subgroupShuffleXor(max_val, mask);
const uint expert = subgroupShuffleXor(max_expert, mask);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
output_weights[k / WARP_SIZE] = max_val;
}
if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[ids_offset + k] = max_expert;
if (with_norm) {
wt_sum += max_val;
}
}
}
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
output_weights[i] *= inv_sum;
}
}
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
if (idx < n_expert_used) {
weights[weights_offset + idx] = output_weights[i];
}
}
}


@ -611,9 +611,6 @@ void process_shaders() {
} }
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
if (tname == "f32") {
continue;
}
if (tname == "bf16") continue; if (tname == "bf16") continue;
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@ -630,7 +627,7 @@ void process_shaders() {
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") { } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
@ -639,7 +636,7 @@ void process_shaders() {
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") { } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
@ -919,6 +916,13 @@ void process_shaders() {
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_moe_f32", "topk_moe.comp", {});
for (auto &c : compiles) { for (auto &c : compiles) {
c.wait(); c.wait();
} }
@ -962,7 +966,7 @@ void write_output_files() {
} }
std::string suffixes[2] = {"_f32", "_f16"}; std::string suffixes[2] = {"_f32", "_f16"};
for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";

Some files were not shown because too many files have changed in this diff.