mirror of https://github.com/ollama/ollama

commit 544b6739dd
parent c4ba257c64

    ggml update to b6840 (#12791)
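Summary of the changes visible in this diff: the vendored llama.cpp sources are synced to a newer upstream commit (FETCH_HEAD/LLAMA_COMMIT move from 7049736b… to 3cfa9c3f…); the JSON-schema grammar builder widens its integer bounds from int to int64_t; a new BAILINGMOE2 architecture is added (arch and tensor names, hparams, tensor loading, graph build, plus the bailing-think/bailing2 chat templates) together with grouped expert routing in build_moe_ffn; the cacheless attention path gains SWA mask support; GPU buffer-type lists can pick up extra device buffer types; and the model loader now keeps each ggml context paired with its backend buffer.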
@@ -1,6 +1,6 @@
 UPSTREAM=https://github.com/ggml-org/llama.cpp.git
 WORKDIR=llama/vendor
-FETCH_HEAD=7049736b2dd9011bf819e298b844ebbc4b5afdc9
+FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
 
 .PHONY: help
 help:
@@ -1,4 +1,4 @@
 int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "7049736b2dd9011bf819e298b844ebbc4b5afdc9";
+char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
 char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";
@@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }
 
-static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
-    auto has_min = min_value != std::numeric_limits<int>::min();
-    auto has_max = max_value != std::numeric_limits<int>::max();
+static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
+    auto has_min = min_value != std::numeric_limits<int64_t>::min();
+    auto has_max = max_value != std::numeric_limits<int64_t>::max();
 
     auto digit_range = [&](char from, char to) {
         out << "[";
@@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
     if (has_min) {
         if (min_value < 0) {
             out << "\"-\" (";
-            _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
+            _build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
             out << ") | [0] | [1-9] ";
             more_digits(0, decimals_left - 1);
         } else if (min_value == 0) {
@@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
             }
             digit_range(c, c);
             out << " (";
-            _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
+            _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
             out << ")";
             if (c < '9') {
                 out << " | ";
@@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
     } else {
         out << "\"-\" (";
-        _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
+        _build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
         out << ")";
     }
     return;
@@ -925,17 +925,17 @@ public:
             int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
             return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
         } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
-            int min_value = std::numeric_limits<int>::min();
-            int max_value = std::numeric_limits<int>::max();
+            int64_t min_value = std::numeric_limits<int64_t>::min();
+            int64_t max_value = std::numeric_limits<int64_t>::max();
             if (schema.contains("minimum")) {
-                min_value = schema["minimum"].get<int>();
+                min_value = schema["minimum"].get<int64_t>();
             } else if (schema.contains("exclusiveMinimum")) {
-                min_value = schema["exclusiveMinimum"].get<int>() + 1;
+                min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
             }
             if (schema.contains("maximum")) {
-                max_value = schema["maximum"].get<int>();
+                max_value = schema["maximum"].get<int64_t>();
             } else if (schema.contains("exclusiveMaximum")) {
-                max_value = schema["exclusiveMaximum"].get<int>() - 1;
+                max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
             }
             std::stringstream out;
             out << "(";
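Illustrative note (not part of the diff): the json-schema-to-grammar hunks above widen the schema bounds from int to int64_t so that "minimum"/"maximum" values outside the 32-bit range survive intact. A minimal standalone C++ sketch of the difference, with a made-up bound value:

#include <cstdint>
#include <iostream>

int main() {
    // A JSON schema bound such as 10000000000 does not fit in 32 bits.
    const int64_t minimum = 10000000000LL;

    // Truncating it to int32 (the old behaviour) silently changes the constraint;
    // keeping it as int64_t (the new behaviour) preserves the value.
    const int32_t truncated = static_cast<int32_t>(minimum);

    std::cout << "as int64_t: " << minimum   << "\n";
    std::cout << "as int32_t: " << truncated << "\n"; // not 10000000000
    return 0;
}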
@@ -5,6 +5,7 @@
 #include <map>
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+    { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
     { LLM_ARCH_LLAMA, "llama" },
     { LLM_ARCH_LLAMA4, "llama4" },
     { LLM_ARCH_DECI, "deci" },
@@ -85,6 +86,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM, "plm" },
     { LLM_ARCH_BAILINGMOE, "bailingmoe" },
+    { LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
     { LLM_ARCH_ERNIE4_5, "ernie4_5" },
@@ -135,6 +137,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
     { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
     { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
+    { LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" },
+    { LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
     { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
     { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
     { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
@@ -277,6 +281,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
+    {
+        LLM_ARCH_CLIP,
+        {},
+    },
     {
         LLM_ARCH_LLAMA,
         {
@@ -1961,6 +1969,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_BAILINGMOE2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
+            { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
+            { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
+            { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
+            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_DOTS1,
         {
@@ -9,6 +9,7 @@
 //
 
 enum llm_arch {
+    LLM_ARCH_CLIP,
     LLM_ARCH_LLAMA,
     LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
@@ -89,6 +90,7 @@ enum llm_arch {
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
+    LLM_ARCH_BAILINGMOE2,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
@@ -139,6 +141,8 @@ enum llm_kv {
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_EXPERT_SHARED_COUNT,
+    LLM_KV_EXPERT_GROUP_COUNT,
+    LLM_KV_EXPERT_GROUP_USED_COUNT,
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
@@ -123,7 +123,7 @@ private:
     uint32_t n_seq_max;
     uint32_t n_outputs;
 
-    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+    std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id
 
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
@@ -63,6 +63,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
     { "yandex", LLM_CHAT_TEMPLATE_YANDEX },
     { "bailing", LLM_CHAT_TEMPLATE_BAILING },
+    { "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK },
+    { "bailing2", LLM_CHAT_TEMPLATE_BAILING2 },
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
@@ -191,6 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_YANDEX;
     } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
         return LLM_CHAT_TEMPLATE_BAILING;
+    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("\"HUMAN\"") && tmpl_contains("<think>")) {
+        return LLM_CHAT_TEMPLATE_BAILING_THINK;
+    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("<role>HUMAN</role>") && tmpl_contains("<|role_end|>")) {
+        return LLM_CHAT_TEMPLATE_BAILING2;
     } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
     } else if (tmpl_contains("<|endofuserprompt|>")) {
@@ -644,8 +650,8 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << " Ассистент:[SEP]";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
-        // Bailing (Ling) template
+    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
+        // Bailing (Ling/Ring) template
         for (auto message : chat) {
             std::string role(message->role);
 
@@ -658,6 +664,33 @@ int32_t llm_chat_apply_template(
             ss << "<role>" << role << "</role>" << message->content;
         }
 
+        if (add_ass) {
+            ss << "<role>ASSISTANT</role>";
+
+            if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
+                ss << "<think>";
+            }
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) {
+        // Bailing2 (Ling 2.0) template
+        bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
+
+        if (!has_system) {
+            ss << "<role>SYSTEM</role>detailed thinking off<|role_end|>";
+        }
+
+        for (auto message : chat) {
+            std::string role(message->role);
+
+            if (role == "user") {
+                role = "HUMAN";
+            } else {
+                std::transform(role.begin(), role.end(), role.begin(), ::toupper);
+            }
+
+            ss << "<role>" << role << "</role>" << message->content << "<|role_end|>";
+        }
+
         if (add_ass) {
             ss << "<role>ASSISTANT</role>";
         }
@@ -42,6 +42,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MEGREZ,
     LLM_CHAT_TEMPLATE_YANDEX,
     LLM_CHAT_TEMPLATE_BAILING,
+    LLM_CHAT_TEMPLATE_BAILING_THINK,
+    LLM_CHAT_TEMPLATE_BAILING2,
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
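For reference, applying the new BAILING2 branch added above to a single hypothetical user message ("Hello", no system message, add_ass enabled) would produce a prompt of roughly this shape, as implied by the code:

<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role>Hello<|role_end|><role>ASSISTANT</role>

The bailing-think variant differs only in appending <think> after the assistant role tag.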
@@ -2345,7 +2345,8 @@ llama_context * llama_init_from_model(
         return nullptr;
     }
 
-    if (params.pooling_type != model->hparams.pooling_type) {
+    if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
+        params.pooling_type != model->hparams.pooling_type) {
         //user-specified pooling-type is different from the model default
         LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
                 model->hparams.pooling_type, params.pooling_type);
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
-                                (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
-                                (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,51 +300,68 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-    float * data = (float *) kq_mask->data;
-
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
-
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
-
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
-                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
-
-                    if (s0 != s1) {
-                        continue; // skip different sequences
-                    }
-
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
-                    }
-
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
-
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
-                    }
-                }
-
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
-            }
-        }
-    }
-
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
-    }
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos    p1 = ubatch->pos[i1];
+
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
+
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos    p0 = ubatch->pos[i0];
+
+                    // mask different sequences
+                    if (s0 != s1) {
+                        continue;
+                    }
+
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
+                    }
+
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
+                    }
+
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                }
+            }
+        }
+    };
+
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+        float * data = (float *) self_kq_mask->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
+    }
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+        float * data = (float *) self_kq_mask_swa->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
+    }
 }
 
 void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
@@ -928,6 +950,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }
 
+    // select top n_group_used expert groups
+    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
+    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
+        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
+
+        // organize experts into n_expert_groups
+        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
+
+        ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
+
+        // get top n_group_used expert groups
+        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
+        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
+
+        ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
+        cb(expert_groups, "ffn_moe_group_topk", il);
+
+        // mask out the other groups
+        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
+        selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
+        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
+        cb(selection_probs, "ffn_moe_probs_masked", il);
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
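Illustrative sketch (not part of the diff): the hunk above adds DeepSeek-V3-style grouped expert routing to build_moe_ffn using ggml tensor ops. The same per-token logic, written with plain C++ containers and a hypothetical helper name, is roughly: score each group by its two best expert probabilities, keep the top n_group_used groups, and mask every expert in the remaining groups before the usual top-k expert pick.

#include <algorithm>
#include <functional>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical helper, not llama.cpp API. Assumes probs.size() is a multiple of
// n_groups and each group holds at least two experts.
std::vector<float> mask_expert_groups(std::vector<float> probs, int n_groups, int n_group_used) {
    const int n_per_group = (int) probs.size() / n_groups;

    // score each group by the sum of its two best expert probabilities
    std::vector<std::pair<float, int>> group_scores;
    for (int g = 0; g < n_groups; ++g) {
        std::vector<float> grp(probs.begin() + g * n_per_group, probs.begin() + (g + 1) * n_per_group);
        std::partial_sort(grp.begin(), grp.begin() + 2, grp.end(), std::greater<float>());
        group_scores.emplace_back(grp[0] + grp[1], g);
    }

    // keep the n_group_used best groups
    std::sort(group_scores.begin(), group_scores.end(), std::greater<>());
    std::vector<bool> keep(n_groups, false);
    for (int i = 0; i < n_group_used; ++i) {
        keep[group_scores[i].second] = true;
    }

    // mask every expert outside the chosen groups so the later top-k expert
    // selection can only pick from those groups
    for (int g = 0; g < n_groups; ++g) {
        if (!keep[g]) {
            std::fill(probs.begin() + g * n_per_group, probs.begin() + (g + 1) * n_per_group,
                      -std::numeric_limits<float>::infinity());
        }
    }
    return probs;
}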
@@ -959,6 +1006,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
     cb(weights_sum, "ffn_moe_weights_sum", il);
 
+    if (arch == LLM_ARCH_BAILINGMOE2) {
+        weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
+        cb(weights_sum, "ffn_moe_weights_sum_biased", il);
+    }
+
     weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights_norm", il);
 
@@ -1299,12 +1351,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         k = ggml_permute(ctx0, k, 0, 2, 1, 3);
         v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;
 
-    // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1419,10 +1468,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
 
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
@@ -1447,7 +1506,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
@@ -257,10 +257,14 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
-    ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
+    // n_tokens == n_batch
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams hparams;
     const llama_cparams cparams;
@@ -74,6 +74,8 @@ struct llama_hparams {
     uint32_t n_ff_chexp = 0;
     uint32_t n_expert_shared = 0;
     uint32_t n_norm_groups = 0;
+    uint32_t n_expert_groups = 0;
+    uint32_t n_group_used = 0;
     uint32_t n_group_experts = 0;
 
     float expert_group_scale = 0.05f;
@@ -114,9 +114,12 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
         case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
         case LLM_TYPE_A13B: return "A13B";
+        case LLM_TYPE_7B_A1B: return "7B.A1B";
         case LLM_TYPE_8B_A1B: return "8B.A1B";
+        case LLM_TYPE_16B_A1B: return "16B.A1B";
         case LLM_TYPE_21B_A3B: return "21B.A3B";
         case LLM_TYPE_30B_A3B: return "30B.A3B";
+        case LLM_TYPE_100B_A6B: return "100B.A6B";
         case LLM_TYPE_106B_A12B: return "106B.A12B";
         case LLM_TYPE_235B_A22B: return "235B.A22B";
         case LLM_TYPE_300B_A47B: return "300B.A47B";
@@ -401,6 +404,19 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s
     // add the device default buffer type
     buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
 
+    // add the device extra buffer type (if any)
+    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+        ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts");
+
+    if (ggml_backend_dev_get_extra_bufts_fn) {
+        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev);
+        while (extra_bufts && *extra_bufts) {
+            buft_list.emplace_back(dev, *extra_bufts);
+            ++extra_bufts;
+        }
+    }
+
     return buft_list;
 }
 
@@ -421,11 +437,8 @@ struct llama_model::impl {
     llama_mlocks mlock_bufs;
     llama_mlocks mlock_mmaps;
 
-    // contexts where the model tensors metadata is stored
-    std::vector<ggml_context_ptr> ctxs;
-
-    // the model memory buffers for the tensor data
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // contexts where the model tensors metadata is stored as well ass the corresponding buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     buft_list_t cpu_buft_list;
     std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
@@ -478,7 +491,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
 
     // everything past this point is not vocab-related
-    if (hparams.vocab_only) {
+    // for CLIP models, we only need to load tensors, no hparams
+    if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
         return;
     }
 
@@ -487,6 +501,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
+    ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
 
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@@ -502,8 +518,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
     if (hparams.n_expert > 0) {
         GGML_ASSERT(hparams.n_expert_used > 0);
+        GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert);
+        if (hparams.n_expert_groups > 1) {
+            GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0);
+            GGML_ASSERT(hparams.n_group_used > 0);
+            GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups);
+        }
     } else {
         GGML_ASSERT(hparams.n_expert_used == 0);
+        GGML_ASSERT(hparams.n_expert_groups == 0);
     }
 
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
@@ -1845,8 +1868,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
-                switch (hparams.n_layer) {
-                    // TODO: Add llm type label (not sure this is useful)
+                switch (hparams.n_embd) {
+                    case 1536: type = LLM_TYPE_7B_A1B; break;
+                    case 2048: case 2560: type = LLM_TYPE_3B; break;
+                    case 4096: type = LLM_TYPE_32B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
 
@@ -1902,6 +1927,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BAILINGMOE2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
+                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+
+                // TODO: when MTP is implemented, this should probably be updated if needed
+                hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
+                switch (hparams.n_layer) {
+                    case 20: type = LLM_TYPE_16B_A1B; break;
+                    case 21: type = LLM_TYPE_16B_A1B; break;
+                    case 32: type = LLM_TYPE_100B_A6B; break;
+                    case 33: type = LLM_TYPE_100B_A6B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DOTS1:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -2196,7 +2244,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     max_n_tensors += n_layer*2; // duplicated rope freq tensors
     const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
 
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
 
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -2211,12 +2266,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                throw std::runtime_error(format("failed to create ggml context"));
            }
 
-            ctx_map[buft] = ctx;
-            pimpl->ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
            return ctx;
        }
-        return it->second;
+        return it->second.get();
    };
 
    const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED;
@@ -5534,6 +5588,70 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
                     }
                 } break;
+            case LLM_ARCH_BAILINGMOE2:
+                {
+                    const int64_t n_ff_exp = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2");
+                    GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2");
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        int flags = 0;
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            // skip all tensors in the NextN layers
+                            flags |= TENSOR_SKIP;
+                        }
+
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
+
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
+
+                        if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
+                            const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared;
+
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
+
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
+                            layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
+
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
+                            layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
+                        } else { // Dense layers
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
+                            layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
+                        }
+
+                        // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
+                            layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
+                            layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
+                            layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags);
+                            layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags);
+                        }
+                    }
+                } break;
             case LLM_ARCH_DOTS1:
                 {
                     const int64_t n_ff_exp = hparams.n_ff_exp;
@@ -6079,16 +6197,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     pimpl->mappings.reserve(ml.mappings.size());
 
     // create the backend buffers
-    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
-    ctx_bufs.reserve(ctx_map.size());
+    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_buf_maps;
+    ctx_buf_maps.reserve(ctx_map.size());
 
     // Ensure we have enough capacity for the maximum backend buffer we will potentially create
     const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
-    pimpl->bufs.reserve(n_max_backend_buffer);
+    pimpl->ctxs_bufs.reserve(n_max_backend_buffer);
 
-    for (auto & it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx = it.second;
+    for (auto & [buft, ctx_ptr] : ctx_map) {
+        ggml_context * ctx = ctx_ptr.get();
 
         // skip contexts without tensors
         if (ggml_get_first_tensor(ctx) == nullptr) {
@@ -6112,6 +6229,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
         bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
 
+        ggml_backend_buffer_t buf = nullptr;
         if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
             for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
                 // only the mmap region containing the tensors in the model is mapped to the backend buffer
@@ -6124,20 +6242,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    continue;
                }
                const size_t max_size = ggml_get_max_tensor_size(ctx);
-               ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
+               buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
                if (buf == nullptr) {
                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
                }
-               pimpl->bufs.emplace_back(buf);
                buf_map.emplace(idx, buf);
            }
        }
        else {
-           ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+           buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
            if (buf == nullptr) {
                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
            }
-           pimpl->bufs.emplace_back(buf);
            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
                pimpl->mlock_bufs.emplace_back(new llama_mlock);
                auto & mlock_buf = pimpl->mlock_bufs.back();
@@ -6148,10 +6264,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                buf_map.emplace(idx, buf);
            }
        }
-
-       if (pimpl->bufs.empty()) {
-           throw std::runtime_error("failed to allocate buffer");
-       }
+       pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
 
        for (auto & buf : buf_map) {
            // indicate that this buffer contains weights
@@ -6159,7 +6272,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
        }
 
-       ctx_bufs.emplace_back(ctx, buf_map);
+       ctx_buf_maps.emplace_back(ctx, buf_map);
    }
 
    if (llama_supports_gpu_offload()) {
@@ -6177,22 +6290,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
    }
 
    // print memory requirements per buffer type
-   for (auto & buf : pimpl->bufs) {
+   for (auto & [_, buf] : pimpl->ctxs_bufs) {
        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
    }
 
    // populate tensors_by_name
-   for (auto & ctx : pimpl->ctxs) {
+   for (auto & [ctx, _] : pimpl->ctxs_bufs) {
        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
            tensors_by_name.emplace_back(ggml_get_name(cur), cur);
        }
    }
 
    // load tensor data
-   for (auto & it : ctx_bufs) {
-       ggml_context * ctx = it.first;
-       auto & bufs = it.second;
-       if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
+   for (auto & [ctx, buf_map] : ctx_buf_maps) {
+       if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
            return false;
        }
    }
@@ -6232,8 +6343,8 @@ size_t llama_model::n_devices() const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> ret;
-   for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) {
-       ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+   for (const auto & [_, buf] : pimpl->ctxs_bufs) {
+       ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
    }
    return ret;
 }
@@ -6396,6 +6507,19 @@ void llama_model::print_info() const {
        LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
    }
 
+   if (arch == LLM_ARCH_BAILINGMOE2) {
+       LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+       LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+       LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
+       LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+       LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
+       LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
+       LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+       LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
+       LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+       LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers);
+   }
+
    if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
        LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
        LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
@@ -11401,8 +11525,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
    }
 };
 
-struct llm_build_gemma_embedding_iswa : public llm_graph_context {
-    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+struct llm_build_gemma_embedding : public llm_graph_context {
+    llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_k;
 
        ggml_tensor * cur;
@@ -11419,8 +11543,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
        // inp_pos - contains the positions
        ggml_tensor * inp_pos = build_inp_pos();
 
-        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
-        auto * inp_attn = build_attn_inp_kv_iswa();
+        auto * inp_attn = build_attn_inp_no_cache();
 
        ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -17245,6 +17368,150 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };
 
+struct llm_build_bailingmoe2 : public llm_graph_context {
+    llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+        for (int il = 0; il < n_transformer_layers; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_transformer_layers - 1 && inp_out_ids) {
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
+            cb(sa_out, "sa_out", il);
+
+            // MoE branch
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up, NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
+                            true, hparams.expert_weights_scale,
+                            (llama_expert_gating_func_type) hparams.expert_gating_func,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp, NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_dots1 : public llm_graph_context {
     llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
@@ -17900,6 +18167,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);
 
+        res->t_embd = cur;
+
         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);
@@ -19580,7 +19849,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
@@ -19873,7 +20142,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GEMMA_EMBEDDING:
             {
-                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+                llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
             } break;
         case LLM_ARCH_STARCODER2:
             {
@@ -20045,6 +20314,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_bailingmoe>(*this, params);
             } break;
+        case LLM_ARCH_BAILINGMOE2:
+            {
+                llm = std::make_unique<llm_build_bailingmoe2>(*this, params);
+            } break;
         case LLM_ARCH_SEED_OSS:
             {
                 llm = std::make_unique<llm_build_seed_oss>(*this, params);
@@ -20220,6 +20493,7 @@ int32_t llama_n_head(const llama_model * model) {
 llama_rope_type llama_model_rope_type(const llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
+        case LLM_ARCH_CLIP:
         case LLM_ARCH_GPT2:
         case LLM_ARCH_GPTJ:
         case LLM_ARCH_MPT:
@@ -20311,6 +20585,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_EXAONE:
         case LLM_ARCH_EXAONE4:
         case LLM_ARCH_MINICPM3:
+        case LLM_ARCH_BAILINGMOE2:
         case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
         case LLM_ARCH_OPENAI_MOE:
 
@@ -108,9 +108,12 @@ enum llm_type {
     LLM_TYPE_17B_16E, // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_A13B,
+    LLM_TYPE_7B_A1B,
     LLM_TYPE_8B_A1B, // lfm2moe
+    LLM_TYPE_16B_A1B,
     LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
+    LLM_TYPE_100B_A6B,
     LLM_TYPE_106B_A12B, // GLM-4.5-Air
     LLM_TYPE_235B_A22B,
     LLM_TYPE_300B_A47B, // Ernie MoE big
 
@@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         });
     }
 
+    bool is_clip_model = false;
     for (const auto * it : tensors) {
         const struct ggml_tensor * tensor = it->tensor;
 
@@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
         }
+
+        is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
     }
 
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
     // sanity checks for models that have attention layers
-    if (qs.n_attention_wv != 0)
+    if (qs.n_attention_wv != 0 && !is_clip_model)
     {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
 
@@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 
+        // do not quantize specific multimodal tensors
+        quantize &= name.find(".position_embd.") == std::string::npos;
+
         ggml_type new_type;
         void * new_data;
         size_t new_size;
 
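For reference, the mmproj detection above hinges on a plain string-prefix test. The following is a minimal standalone sketch of that idea; the helper name is_multimodal_tensor() is hypothetical and not part of the patch.

#include <string>

// A tensor name such as "mm.input_projection.weight" marks a CLIP/mmproj
// checkpoint; rfind(prefix, 0) == 0 is the usual C++ prefix test.
static bool is_multimodal_tensor(const std::string & name) {
    return name.rfind("mm.", 0) == 0;
}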
@@ -1957,6 +1957,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             clean_spaces = false;
         } else if (
                 tokenizer_pre == "bailingmoe" ||
+                tokenizer_pre == "bailingmoe2" ||
                 tokenizer_pre == "llada-moe") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
             clean_spaces = false;
 
@@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
     } catch(const std::exception & e) {
         throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
     }
+    if (model.arch == LLM_ARCH_CLIP) {
+        throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
+    }
     try {
         model.load_vocab(ml);
     } catch(const std::exception & e) {

@@ -314,6 +317,7 @@ struct llama_model * llama_model_load_from_splits(
         LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
         return nullptr;
     }
+    splits.reserve(n_paths);
     for (size_t i = 0; i < n_paths; ++i) {
         splits.push_back(paths[i]);
     }
 
@@ -30,6 +30,7 @@
 #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
 
 // vision-specific
+#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
 #define KEY_IMAGE_SIZE "clip.vision.image_size"
 #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
 #define KEY_PATCH_SIZE "clip.vision.patch_size"
 
@@ -48,6 +49,7 @@
 #define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
 
 // audio-specific
+#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
 #define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
 #define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
 
@@ -2234,15 +2234,27 @@ struct clip_model_loader {
         // projector type
         std::string proj_type;
         {
+            // default key
             get_string(KEY_PROJ_TYPE, proj_type, false);
-            if (!proj_type.empty()) {
-                model.proj_type = clip_projector_type_from_string(proj_type);
+
+            // for models with mixed modalities
+            if (proj_type.empty()) {
+                if (modality == CLIP_MODALITY_VISION) {
+                    get_string(KEY_VISION_PROJ_TYPE, proj_type, false);
+                } else if (modality == CLIP_MODALITY_AUDIO) {
+                    get_string(KEY_AUDIO_PROJ_TYPE, proj_type, false);
+                } else {
+                    GGML_ABORT("unknown modality");
+                }
             }
+        }
 
+        model.proj_type = clip_projector_type_from_string(proj_type);
 
         if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
             throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
         }
 
-        // correct arch for multimodal models
+        // correct arch for multimodal models (legacy method)
         if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
             model.proj_type = modality == CLIP_MODALITY_VISION
                 ? PROJECTOR_TYPE_QWEN25VL
 
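The lookup order introduced above amounts to: try the shared key first, then fall back to a modality-specific key for mixed-modality files. Below is a self-contained sketch of that order; the function name resolve_proj_type() and the string-getter callback are illustrative stand-ins for clip_model_loader::get_string(), and only the GGUF key names are taken from the patch.

#include <functional>
#include <string>

enum class clip_modality_sketch { vision, audio };

// Illustrative only: mirrors the fallback order, not the real loader API.
static std::string resolve_proj_type(const std::function<std::string(const std::string &)> & get_key,
                                      clip_modality_sketch modality) {
    std::string proj = get_key("clip.projector_type"); // default key
    if (proj.empty()) {
        // modality-specific keys for models with mixed modalities
        proj = get_key(modality == clip_modality_sketch::vision
                           ? "clip.vision.projector_type"
                           : "clip.audio.projector_type");
    }
    return proj;
}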
@ -23,7 +23,7 @@ problem.
|
||||||
8 files changed, 21 insertions(+), 2 deletions(-)
|
8 files changed, 21 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||||
index ff9135fe..8ba86f82 100644
|
index ff9135fe2..8ba86f824 100644
|
||||||
--- a/ggml/src/ggml-backend.cpp
|
--- a/ggml/src/ggml-backend.cpp
|
||||||
+++ b/ggml/src/ggml-backend.cpp
|
+++ b/ggml/src/ggml-backend.cpp
|
||||||
@@ -113,7 +113,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
|
@@ -113,7 +113,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
|
||||||
|
|
@ -64,18 +64,18 @@ index ff9135fe..8ba86f82 100644
|
||||||
/* .init_tensor = */ NULL, // no initialization required
|
/* .init_tensor = */ NULL, // no initialization required
|
||||||
/* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
|
/* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
|
||||||
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
|
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
|
||||||
index ad1adba6..7d44f74f 100755
|
index 8bd5449f1..01e2df61a 100644
|
||||||
--- a/ggml/src/ggml-cann/ggml-cann.cpp
|
--- a/ggml/src/ggml-cann/ggml-cann.cpp
|
||||||
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
|
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
|
||||||
@@ -843,6 +843,7 @@ static void ggml_backend_cann_buffer_free_buffer(
|
@@ -820,6 +820,7 @@ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_cann_buffer_context* ctx =
|
static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
(ggml_backend_cann_buffer_context*)buffer->context;
|
ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
+ delete buffer;
|
+ delete buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -1630,6 +1631,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
|
@@ -1560,6 +1561,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
|
||||||
*/
|
*/
|
||||||
static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
|
||||||
ACL_CHECK(aclrtFreeHost(buffer->context));
|
ACL_CHECK(aclrtFreeHost(buffer->context));
|
||||||
|
|
@ -84,10 +84,10 @@ index ad1adba6..7d44f74f 100755
|
||||||
|
|
||||||
/**
|
/**
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index 856e9de2..c0b1e4c1 100644
|
index bc396b521..aefc6935e 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -567,6 +567,7 @@ struct ggml_backend_cuda_buffer_context {
|
@@ -576,6 +576,7 @@ struct ggml_backend_cuda_buffer_context {
|
||||||
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
|
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -95,7 +95,7 @@ index 856e9de2..c0b1e4c1 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
|
static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
|
||||||
@@ -822,6 +823,7 @@ struct ggml_backend_cuda_split_buffer_context {
|
@@ -831,6 +832,7 @@ struct ggml_backend_cuda_split_buffer_context {
|
||||||
static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
|
ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -103,7 +103,7 @@ index 856e9de2..c0b1e4c1 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
@@ -1103,6 +1105,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
|
@@ -1112,6 +1114,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
|
||||||
|
|
||||||
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
CUDA_CHECK(cudaFreeHost(buffer->context));
|
CUDA_CHECK(cudaFreeHost(buffer->context));
|
||||||
|
|
@ -112,7 +112,7 @@ index 856e9de2..c0b1e4c1 100644
|
||||||
|
|
||||||
static void * ggml_cuda_host_malloc(size_t size) {
|
static void * ggml_cuda_host_malloc(size_t size) {
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
|
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||||
index 7afc881f..bf096227 100644
|
index 7afc881fa..bf0962274 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.cpp
|
--- a/ggml/src/ggml-metal/ggml-metal.cpp
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
|
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||||
@@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b
|
@@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b
|
||||||
|
|
@ -132,10 +132,10 @@ index 7afc881f..bf096227 100644
|
||||||
|
|
||||||
static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
|
||||||
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
||||||
index 79d21487..38c75018 100644
|
index db33a4ab6..c42ee26e1 100644
|
||||||
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
|
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
|
||||||
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
|
||||||
@@ -3212,6 +3212,7 @@ struct ggml_backend_opencl_buffer_context {
|
@@ -3266,6 +3266,7 @@ struct ggml_backend_opencl_buffer_context {
|
||||||
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -144,7 +144,7 @@ index 79d21487..38c75018 100644
|
||||||
|
|
||||||
static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
|
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
|
||||||
index aad48d62..a46c0f52 100644
|
index a38df5a97..fd07e4a21 100644
|
||||||
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
|
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
|
||||||
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
|
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
|
||||||
@@ -528,6 +528,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
@@ -528,6 +528,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
|
|
@ -156,10 +156,10 @@ index aad48d62..a46c0f52 100644
|
||||||
|
|
||||||
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
||||||
index 45b8c216..4ec9a592 100644
|
index b695ba051..37e853120 100644
|
||||||
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
|
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
|
||||||
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
|
||||||
@@ -334,6 +334,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
|
@@ -352,6 +352,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
|
||||||
ggml_sycl_set_device(ctx->device);
|
ggml_sycl_set_device(ctx->device);
|
||||||
|
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -167,7 +167,7 @@ index 45b8c216..4ec9a592 100644
|
||||||
}
|
}
|
||||||
catch (sycl::exception const &exc) {
|
catch (sycl::exception const &exc) {
|
||||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||||
@@ -795,6 +796,7 @@ struct ggml_backend_sycl_split_buffer_context {
|
@@ -813,6 +814,7 @@ struct ggml_backend_sycl_split_buffer_context {
|
||||||
static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
|
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -175,7 +175,7 @@ index 45b8c216..4ec9a592 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
@@ -1137,6 +1139,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
|
@@ -1155,6 +1157,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
|
||||||
|
|
||||||
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_sycl_host_free(buffer->context);
|
ggml_sycl_host_free(buffer->context);
|
||||||
|
|
@ -184,10 +184,10 @@ index 45b8c216..4ec9a592 100644
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
index 3cd89c71..ed83236f 100644
|
index b783f7805..216dc167c 100644
|
||||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
@@ -11600,6 +11600,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
@@ -11828,6 +11828,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||||
ggml_vk_destroy_buffer(ctx->dev_buffer);
|
ggml_vk_destroy_buffer(ctx->dev_buffer);
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
@ -195,7 +195,7 @@ index 3cd89c71..ed83236f 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
@@ -11743,6 +11744,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
|
@@ -11971,6 +11972,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
|
||||||
static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
|
VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
|
||||||
ggml_vk_host_free(vk_instance.devices[0], buffer->context);
|
ggml_vk_host_free(vk_instance.devices[0], buffer->context);
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ logs instead of throwing an error
|
||||||
1 file changed, 3 insertions(+), 11 deletions(-)
|
1 file changed, 3 insertions(+), 11 deletions(-)
|
||||||
|
|
||||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
||||||
index 7fffd171..0b6edaf4 100644
|
index 639fecbd3..a7ce6f8e1 100644
|
||||||
--- a/src/llama-vocab.cpp
|
--- a/src/llama-vocab.cpp
|
||||||
+++ b/src/llama-vocab.cpp
|
+++ b/src/llama-vocab.cpp
|
||||||
@@ -1812,16 +1812,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
@@ -1812,16 +1812,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
|
|
@ -31,7 +31,7 @@ index 7fffd171..0b6edaf4 100644
|
||||||
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "llama3" ||
|
tokenizer_pre == "llama3" ||
|
||||||
@@ -1992,7 +1983,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
@@ -1993,7 +1984,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
|
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
|
||||||
clean_spaces = false;
|
clean_spaces = false;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ filesystems for paths that include wide characters
|
||||||
1 file changed, 39 insertions(+)
|
1 file changed, 39 insertions(+)
|
||||||
|
|
||||||
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
|
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
|
||||||
index 98e68af2..6699b75a 100644
|
index f2abf8852..c984e6282 100644
|
||||||
--- a/tools/mtmd/clip.cpp
|
--- a/tools/mtmd/clip.cpp
|
||||||
+++ b/tools/mtmd/clip.cpp
|
+++ b/tools/mtmd/clip.cpp
|
||||||
@@ -28,6 +28,19 @@
|
@@ -28,6 +28,19 @@
|
||||||
|
|
@ -33,7 +33,7 @@ index 98e68af2..6699b75a 100644
|
||||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||||
|
|
||||||
enum ffn_op_type {
|
enum ffn_op_type {
|
||||||
@@ -2762,7 +2775,29 @@ struct clip_model_loader {
|
@@ -2774,7 +2787,29 @@ struct clip_model_loader {
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> read_buf;
|
std::vector<uint8_t> read_buf;
|
||||||
|
|
||||||
|
|
@ -63,7 +63,7 @@ index 98e68af2..6699b75a 100644
|
||||||
if (!fin) {
|
if (!fin) {
|
||||||
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
|
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
|
||||||
}
|
}
|
||||||
@@ -2789,7 +2824,11 @@ struct clip_model_loader {
|
@@ -2801,7 +2836,11 @@ struct clip_model_loader {
|
||||||
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,10 @@ adds support for the Solar Pro architecture
|
||||||
7 files changed, 248 insertions(+), 1 deletion(-)
|
7 files changed, 248 insertions(+), 1 deletion(-)
|
||||||
|
|
||||||
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
||||||
index 869e4dcc..9f6b6ad2 100644
|
index 8ca769c5f..ab262ec0c 100644
|
||||||
--- a/src/llama-arch.cpp
|
--- a/src/llama-arch.cpp
|
||||||
+++ b/src/llama-arch.cpp
|
+++ b/src/llama-arch.cpp
|
||||||
@@ -81,6 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
@@ -82,6 +82,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||||
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
|
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
|
||||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||||
|
|
@ -26,7 +26,7 @@ index 869e4dcc..9f6b6ad2 100644
|
||||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||||
{ LLM_ARCH_PLM, "plm" },
|
{ LLM_ARCH_PLM, "plm" },
|
||||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||||
@@ -179,6 +180,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
@@ -183,6 +184,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||||
|
|
@ -34,7 +34,7 @@ index 869e4dcc..9f6b6ad2 100644
|
||||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||||
|
|
||||||
@@ -1893,6 +1895,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
@@ -1901,6 +1903,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -59,7 +59,7 @@ index 869e4dcc..9f6b6ad2 100644
|
||||||
{
|
{
|
||||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||||
{
|
{
|
||||||
@@ -2429,6 +2449,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
@@ -2469,6 +2489,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||||
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||||
// this tensor is loaded for T5, but never used
|
// this tensor is loaded for T5, but never used
|
||||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||||
|
|
@ -68,10 +68,10 @@ index 869e4dcc..9f6b6ad2 100644
|
||||||
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||||
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||||
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
||||||
index c3ae7165..dc7a362a 100644
|
index dea725c1a..ea2b4ffb9 100644
|
||||||
--- a/src/llama-arch.h
|
--- a/src/llama-arch.h
|
||||||
+++ b/src/llama-arch.h
|
+++ b/src/llama-arch.h
|
||||||
@@ -85,6 +85,7 @@ enum llm_arch {
|
@@ -86,6 +86,7 @@ enum llm_arch {
|
||||||
LLM_ARCH_GRANITE_MOE,
|
LLM_ARCH_GRANITE_MOE,
|
||||||
LLM_ARCH_GRANITE_HYBRID,
|
LLM_ARCH_GRANITE_HYBRID,
|
||||||
LLM_ARCH_CHAMELEON,
|
LLM_ARCH_CHAMELEON,
|
||||||
|
|
@ -79,7 +79,7 @@ index c3ae7165..dc7a362a 100644
|
||||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||||
LLM_ARCH_PLM,
|
LLM_ARCH_PLM,
|
||||||
LLM_ARCH_BAILINGMOE,
|
LLM_ARCH_BAILINGMOE,
|
||||||
@@ -183,6 +184,7 @@ enum llm_kv {
|
@@ -187,6 +188,7 @@ enum llm_kv {
|
||||||
LLM_KV_ATTENTION_SCALE,
|
LLM_KV_ATTENTION_SCALE,
|
||||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||||
|
|
@ -87,7 +87,7 @@ index c3ae7165..dc7a362a 100644
|
||||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||||
|
|
||||||
@@ -432,6 +434,7 @@ enum llm_tensor {
|
@@ -436,6 +438,7 @@ enum llm_tensor {
|
||||||
LLM_TENSOR_ENC_OUTPUT_NORM,
|
LLM_TENSOR_ENC_OUTPUT_NORM,
|
||||||
LLM_TENSOR_CLS,
|
LLM_TENSOR_CLS,
|
||||||
LLM_TENSOR_CLS_OUT,
|
LLM_TENSOR_CLS_OUT,
|
||||||
|
|
@ -96,7 +96,7 @@ index c3ae7165..dc7a362a 100644
|
||||||
LLM_TENSOR_CONVNEXT_DW,
|
LLM_TENSOR_CONVNEXT_DW,
|
||||||
LLM_TENSOR_CONVNEXT_NORM,
|
LLM_TENSOR_CONVNEXT_NORM,
|
||||||
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
|
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
|
||||||
index db65d69e..b6bf6bbf 100644
|
index db65d69ea..b6bf6bbf2 100644
|
||||||
--- a/src/llama-hparams.cpp
|
--- a/src/llama-hparams.cpp
|
||||||
+++ b/src/llama-hparams.cpp
|
+++ b/src/llama-hparams.cpp
|
||||||
@@ -151,6 +151,14 @@ uint32_t llama_hparams::n_pos_per_embd() const {
|
@@ -151,6 +151,14 @@ uint32_t llama_hparams::n_pos_per_embd() const {
|
||||||
|
|
@ -115,7 +115,7 @@ index db65d69e..b6bf6bbf 100644
|
||||||
if (il < n_layer) {
|
if (il < n_layer) {
|
||||||
return swa_layers[il];
|
return swa_layers[il];
|
||||||
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
|
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
|
||||||
index 4e7f73ec..80582728 100644
|
index 6fcf91b7d..24569a258 100644
|
||||||
--- a/src/llama-hparams.h
|
--- a/src/llama-hparams.h
|
||||||
+++ b/src/llama-hparams.h
|
+++ b/src/llama-hparams.h
|
||||||
@@ -64,6 +64,8 @@ struct llama_hparams {
|
@@ -64,6 +64,8 @@ struct llama_hparams {
|
||||||
|
|
@ -127,7 +127,7 @@ index 4e7f73ec..80582728 100644
|
||||||
uint32_t n_layer_dense_lead = 0;
|
uint32_t n_layer_dense_lead = 0;
|
||||||
uint32_t n_lora_q = 0;
|
uint32_t n_lora_q = 0;
|
||||||
uint32_t n_lora_kv = 0;
|
uint32_t n_lora_kv = 0;
|
||||||
@@ -248,6 +250,9 @@ struct llama_hparams {
|
@@ -250,6 +252,9 @@ struct llama_hparams {
|
||||||
|
|
||||||
uint32_t n_pos_per_embd() const;
|
uint32_t n_pos_per_embd() const;
|
||||||
|
|
||||||
|
|
@ -138,7 +138,7 @@ index 4e7f73ec..80582728 100644
|
||||||
|
|
||||||
bool has_kv(uint32_t il) const;
|
bool has_kv(uint32_t il) const;
|
||||||
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
|
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
|
||||||
index aa3a65f8..ee303bd5 100644
|
index aa3a65f87..ee303bd58 100644
|
||||||
--- a/src/llama-model-loader.cpp
|
--- a/src/llama-model-loader.cpp
|
||||||
+++ b/src/llama-model-loader.cpp
|
+++ b/src/llama-model-loader.cpp
|
||||||
@@ -466,7 +466,7 @@ namespace GGUFMeta {
|
@@ -466,7 +466,7 @@ namespace GGUFMeta {
|
||||||
|
|
@ -151,10 +151,10 @@ index aa3a65f8..ee303bd5 100644
|
||||||
llama_model_loader::llama_model_loader(
|
llama_model_loader::llama_model_loader(
|
||||||
const std::string & fname,
|
const std::string & fname,
|
||||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||||
index 36d495d6..74e1d162 100644
|
index 2a83d6627..54621ea39 100644
|
||||||
--- a/src/llama-model.cpp
|
--- a/src/llama-model.cpp
|
||||||
+++ b/src/llama-model.cpp
|
+++ b/src/llama-model.cpp
|
||||||
@@ -1865,6 +1865,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
@@ -1890,6 +1890,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
default: type = LLM_TYPE_UNKNOWN;
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -176,7 +176,7 @@ index 36d495d6..74e1d162 100644
|
||||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
@@ -5170,6 +5185,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
@@ -5224,6 +5239,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
|
|
||||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||||
|
|
||||||
|
|
@ -211,7 +211,7 @@ index 36d495d6..74e1d162 100644
|
||||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||||
@@ -16392,6 +16435,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
@@ -16515,6 +16558,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -377,7 +377,7 @@ index 36d495d6..74e1d162 100644
|
||||||
// ref: https://github.com/facebookresearch/chameleon
|
// ref: https://github.com/facebookresearch/chameleon
|
||||||
// based on the original build_llama() function, changes:
|
// based on the original build_llama() function, changes:
|
||||||
// * qk-norm
|
// * qk-norm
|
||||||
@@ -19827,6 +20029,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
@@ -20096,6 +20298,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||||
{
|
{
|
||||||
llm = std::make_unique<llm_build_chameleon>(*this, params);
|
llm = std::make_unique<llm_build_chameleon>(*this, params);
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -388,7 +388,7 @@ index 36d495d6..74e1d162 100644
|
||||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||||
{
|
{
|
||||||
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params);
|
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params);
|
||||||
@@ -20057,6 +20263,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
@@ -20331,6 +20537,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||||
case LLM_ARCH_GRANITE_MOE:
|
case LLM_ARCH_GRANITE_MOE:
|
||||||
case LLM_ARCH_GRANITE_HYBRID:
|
case LLM_ARCH_GRANITE_HYBRID:
|
||||||
case LLM_ARCH_CHAMELEON:
|
case LLM_ARCH_CHAMELEON:
|
||||||
|
|
@ -397,7 +397,7 @@ index 36d495d6..74e1d162 100644
|
||||||
case LLM_ARCH_NEO_BERT:
|
case LLM_ARCH_NEO_BERT:
|
||||||
case LLM_ARCH_SMOLLM3:
|
case LLM_ARCH_SMOLLM3:
|
||||||
diff --git a/src/llama-model.h b/src/llama-model.h
|
diff --git a/src/llama-model.h b/src/llama-model.h
|
||||||
index 7f48662f..ec3fbd33 100644
|
index 248f85410..4a7924aaa 100644
|
||||||
--- a/src/llama-model.h
|
--- a/src/llama-model.h
|
||||||
+++ b/src/llama-model.h
|
+++ b/src/llama-model.h
|
||||||
@@ -76,6 +76,7 @@ enum llm_type {
|
@@ -76,6 +76,7 @@ enum llm_type {
|
||||||
|
|
@ -408,7 +408,7 @@ index 7f48662f..ec3fbd33 100644
|
||||||
LLM_TYPE_27B,
|
LLM_TYPE_27B,
|
||||||
LLM_TYPE_30B,
|
LLM_TYPE_30B,
|
||||||
LLM_TYPE_32B,
|
LLM_TYPE_32B,
|
||||||
@@ -387,6 +388,8 @@ struct llama_layer {
|
@@ -390,6 +391,8 @@ struct llama_layer {
|
||||||
struct ggml_tensor * ffn_act_beta = nullptr;
|
struct ggml_tensor * ffn_act_beta = nullptr;
|
||||||
struct ggml_tensor * ffn_act_eps = nullptr;
|
struct ggml_tensor * ffn_act_eps = nullptr;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ regex
|
||||||
2 files changed, 22 insertions(+), 1 deletion(-)
|
2 files changed, 22 insertions(+), 1 deletion(-)
|
||||||
|
|
||||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
||||||
index 0b6edaf4..3de95c67 100644
|
index a7ce6f8e1..8064dc197 100644
|
||||||
--- a/src/llama-vocab.cpp
|
--- a/src/llama-vocab.cpp
|
||||||
+++ b/src/llama-vocab.cpp
|
+++ b/src/llama-vocab.cpp
|
||||||
@@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
@@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||||
|
|
@ -25,7 +25,7 @@ index 0b6edaf4..3de95c67 100644
|
||||||
"\\s+$",
|
"\\s+$",
|
||||||
"[一-龥ࠀ-一가-]+",
|
"[一-龥ࠀ-一가-]+",
|
||||||
diff --git a/src/unicode.cpp b/src/unicode.cpp
|
diff --git a/src/unicode.cpp b/src/unicode.cpp
|
||||||
index 65f36651..ce336a22 100644
|
index 65f366517..ce336a228 100644
|
||||||
--- a/src/unicode.cpp
|
--- a/src/unicode.cpp
|
||||||
+++ b/src/unicode.cpp
|
+++ b/src/unicode.cpp
|
||||||
@@ -2,6 +2,11 @@
|
@@ -2,6 +2,11 @@
|
||||||
|
|
|
||||||
|
|
@@ -8,7 +8,7 @@ Subject: [PATCH] maintain ordering for rules for grammar
 1 file changed, 1 insertion(+), 1 deletion(-)
 
 diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
-index db1f0b23..f4de7e34 100644
+index dd9b51a9e..d88f43209 100644
 --- a/common/json-schema-to-grammar.cpp
 +++ b/common/json-schema-to-grammar.cpp
 @@ -308,7 +308,7 @@ private:
 
@ -11,10 +11,10 @@ with the fastest acceleration is loaded
|
||||||
1 file changed, 13 insertions(+), 8 deletions(-)
|
1 file changed, 13 insertions(+), 8 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||||
index 136afec7..f794d9cf 100644
|
index e96b5c403..a55d9b280 100644
|
||||||
--- a/ggml/src/ggml-backend-reg.cpp
|
--- a/ggml/src/ggml-backend-reg.cpp
|
||||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||||
@@ -175,7 +175,7 @@ struct ggml_backend_reg_entry {
|
@@ -179,7 +179,7 @@ struct ggml_backend_reg_entry {
|
||||||
|
|
||||||
struct ggml_backend_registry {
|
struct ggml_backend_registry {
|
||||||
std::vector<ggml_backend_reg_entry> backends;
|
std::vector<ggml_backend_reg_entry> backends;
|
||||||
|
|
@ -23,7 +23,7 @@ index 136afec7..f794d9cf 100644
|
||||||
|
|
||||||
ggml_backend_registry() {
|
ggml_backend_registry() {
|
||||||
#ifdef GGML_USE_CUDA
|
#ifdef GGML_USE_CUDA
|
||||||
@@ -223,7 +223,7 @@ struct ggml_backend_registry {
|
@@ -230,7 +230,7 @@ struct ggml_backend_registry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -32,7 +32,7 @@ index 136afec7..f794d9cf 100644
|
||||||
if (!reg) {
|
if (!reg) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -234,15 +234,20 @@ struct ggml_backend_registry {
|
@@ -241,15 +241,20 @@ struct ggml_backend_registry {
|
||||||
#endif
|
#endif
|
||||||
backends.push_back({ reg, std::move(handle) });
|
backends.push_back({ reg, std::move(handle) });
|
||||||
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
|
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
|
||||||
|
|
@ -56,7 +56,7 @@ index 136afec7..f794d9cf 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {
|
ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {
|
||||||
@@ -286,7 +291,7 @@ struct ggml_backend_registry {
|
@@ -293,7 +298,7 @@ struct ggml_backend_registry {
|
||||||
|
|
||||||
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());
|
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());
|
||||||
|
|
||||||
|
|
@ -65,7 +65,7 @@ index 136afec7..f794d9cf 100644
|
||||||
|
|
||||||
return reg;
|
return reg;
|
||||||
}
|
}
|
||||||
@@ -309,7 +314,7 @@ struct ggml_backend_registry {
|
@@ -316,7 +321,7 @@ struct ggml_backend_registry {
|
||||||
// remove devices
|
// remove devices
|
||||||
devices.erase(
|
devices.erase(
|
||||||
std::remove_if(devices.begin(), devices.end(),
|
std::remove_if(devices.begin(), devices.end(),
|
||||||
|
|
@ -74,7 +74,7 @@ index 136afec7..f794d9cf 100644
|
||||||
devices.end());
|
devices.end());
|
||||||
|
|
||||||
// remove backend
|
// remove backend
|
||||||
@@ -367,7 +372,7 @@ size_t ggml_backend_dev_count() {
|
@@ -374,7 +379,7 @@ size_t ggml_backend_dev_count() {
|
||||||
|
|
||||||
ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
|
ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
|
||||||
GGML_ASSERT(index < ggml_backend_dev_count());
|
GGML_ASSERT(index < ggml_backend_dev_count());
|
||||||
|
|
|
||||||
|
|
@@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants
 1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 892c2331..09fdf5fc 100644
+index ba281b8e6..ead235878 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -310,6 +310,7 @@ function(ggml_add_cpu_backend_variant tag_name)
+@@ -314,6 +314,7 @@ function(ggml_add_cpu_backend_variant tag_name)
     endif()
 
     ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -19,7 +19,7 @@ index 892c2331..09fdf5fc 100644
 endfunction()
 
 ggml_add_backend(CPU)
-@@ -320,6 +321,7 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -324,6 +325,7 @@ if (GGML_CPU_ALL_VARIANTS)
     elseif (GGML_CPU_ARM_ARCH)
         message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS")
     endif()
 
@@ -9,10 +9,10 @@ disable amx as it reduces performance on some systems
 1 file changed, 4 deletions(-)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 09fdf5fc..0609c650 100644
+index ead235878..f9a6587f1 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -330,10 +330,6 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -334,10 +334,6 @@ if (GGML_CPU_ALL_VARIANTS)
     ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
     ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
     ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
 
@ -13,7 +13,7 @@ such as vocab fields
|
||||||
3 files changed, 7 insertions(+), 5 deletions(-)
|
3 files changed, 7 insertions(+), 5 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
|
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
|
||||||
index 79ee2020..3efb22f0 100644
|
index 79ee20206..3efb22f01 100644
|
||||||
--- a/ggml/include/gguf.h
|
--- a/ggml/include/gguf.h
|
||||||
+++ b/ggml/include/gguf.h
|
+++ b/ggml/include/gguf.h
|
||||||
@@ -114,6 +114,7 @@ extern "C" {
|
@@ -114,6 +114,7 @@ extern "C" {
|
||||||
|
|
@ -25,7 +25,7 @@ index 79ee2020..3efb22f0 100644
|
||||||
// get ith C string from array with given key_id
|
// get ith C string from array with given key_id
|
||||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
||||||
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
|
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
|
||||||
index 8cc4ef1c..d950dbdf 100644
|
index 8cc4ef1cf..d950dbdf5 100644
|
||||||
--- a/ggml/src/gguf.cpp
|
--- a/ggml/src/gguf.cpp
|
||||||
+++ b/ggml/src/gguf.cpp
|
+++ b/ggml/src/gguf.cpp
|
||||||
@@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
|
@@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
|
||||||
|
|
@ -53,7 +53,7 @@ index 8cc4ef1c..d950dbdf 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
||||||
index 3de95c67..217ede47 100644
|
index 8064dc197..31f49801c 100644
|
||||||
--- a/src/llama-vocab.cpp
|
--- a/src/llama-vocab.cpp
|
||||||
+++ b/src/llama-vocab.cpp
|
+++ b/src/llama-vocab.cpp
|
||||||
@@ -1768,9 +1768,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
@@ -1768,9 +1768,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
|
|
|
||||||
|
|
@@ -8,7 +8,7 @@ Subject: [PATCH] ollama debug tensor
 1 file changed, 6 insertions(+)
 
 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
-index ba2a36d9..99509b0c 100644
+index 9ec485cfa..4b2f8b7bd 100644
 --- a/ggml/src/ggml-cpu/ggml-cpu.c
 +++ b/ggml/src/ggml-cpu/ggml-cpu.c
 @@ -15,6 +15,8 @@
 
@@ -20,7 +20,7 @@ index ba2a36d9..99509b0c 100644
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
-@@ -2887,6 +2889,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
+@@ -2891,6 +2893,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
     ggml_compute_forward(&params, node);
 
@ -10,7 +10,7 @@ Subject: [PATCH] add ollama vocab for grammar support
|
||||||
3 files changed, 58 insertions(+), 9 deletions(-)
|
3 files changed, 58 insertions(+), 9 deletions(-)
|
||||||
|
|
||||||
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
|
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
|
||||||
index bed706bb..b51cee09 100644
|
index bed706bb2..b51cee090 100644
|
||||||
--- a/src/llama-grammar.cpp
|
--- a/src/llama-grammar.cpp
|
||||||
+++ b/src/llama-grammar.cpp
|
+++ b/src/llama-grammar.cpp
|
||||||
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
|
@ -137,7 +137,7 @@ index bed706bb..b51cee09 100644
|
||||||
+ }
|
+ }
|
||||||
+}
|
+}
|
||||||
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
|
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
|
||||||
index f8c291de..2a3a62db 100644
|
index f8c291de9..2a3a62db3 100644
|
||||||
--- a/src/llama-grammar.h
|
--- a/src/llama-grammar.h
|
||||||
+++ b/src/llama-grammar.h
|
+++ b/src/llama-grammar.h
|
||||||
@@ -6,8 +6,19 @@
|
@@ -6,8 +6,19 @@
|
||||||
|
|
@ -184,7 +184,7 @@ index f8c291de..2a3a62db 100644
|
||||||
const char * grammar_root,
|
const char * grammar_root,
|
||||||
bool lazy,
|
bool lazy,
|
||||||
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
|
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
|
||||||
index 55d2e355..da34526b 100644
|
index 55d2e355f..da34526b1 100644
|
||||||
--- a/src/llama-sampling.cpp
|
--- a/src/llama-sampling.cpp
|
||||||
+++ b/src/llama-sampling.cpp
|
+++ b/src/llama-sampling.cpp
|
||||||
@@ -1563,7 +1563,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
@@ -1563,7 +1563,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,15 @@ Date: Thu, 1 May 2025 13:45:12 -0700
|
||||||
Subject: [PATCH] add argsort and cuda copy for i32
|
Subject: [PATCH] add argsort and cuda copy for i32
|
||||||
|
|
||||||
---
|
---
|
||||||
ggml/src/ggml-cpu/ops.cpp | 43 +++++++++++
|
ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++
|
||||||
ggml/src/ggml-cuda/argsort.cu | 102 ++++++++++++++++++++++++++-
|
ggml/src/ggml-cuda/argsort.cu | 122 ++++++++++++++++++++++++---
|
||||||
ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++
|
ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++
|
||||||
ggml/src/ggml-cuda/cpy.cu | 43 +++++++++++
|
ggml/src/ggml-cuda/cpy.cu | 40 +++++++++
|
||||||
ggml/src/ggml-metal/ggml-metal.metal | 64 +++++++++++++++++
|
ggml/src/ggml-metal/ggml-metal.metal | 64 ++++++++++++++
|
||||||
5 files changed, 256 insertions(+), 2 deletions(-)
|
5 files changed, 263 insertions(+), 12 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||||
index 1c43865f..31478dd8 100644
|
index b52f0f847..902fdad69 100644
|
||||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||||
@@ -7889,6 +7889,45 @@ static void ggml_compute_forward_argsort_f32(
|
@@ -7889,6 +7889,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||||
|
|
@ -73,10 +73,10 @@ index 1c43865f..31478dd8 100644
|
||||||
{
|
{
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
|
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
|
||||||
index 607ded85..53b02634 100644
|
index 6e7b90d42..08dd30525 100644
|
||||||
--- a/ggml/src/ggml-cuda/argsort.cu
|
--- a/ggml/src/ggml-cuda/argsort.cu
|
||||||
+++ b/ggml/src/ggml-cuda/argsort.cu
|
+++ b/ggml/src/ggml-cuda/argsort.cu
|
||||||
@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
@@ -168,13 +168,107 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -185,19 +185,42 @@ index 607ded85..53b02634 100644
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
@@ -183,18 +277,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
||||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
|
|
||||||
- argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
-#ifdef GGML_CUDA_USE_CUB
|
||||||
|
- const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
- const size_t shared_mem = ncols_pad * sizeof(int);
|
||||||
|
- const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
||||||
|
-
|
||||||
|
- if (shared_mem > max_shared_mem || ncols > 1024) {
|
||||||
|
- ggml_cuda_pool & pool = ctx.pool();
|
||||||
|
- argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
+ if (src0->type == GGML_TYPE_I32) {
|
+ if (src0->type == GGML_TYPE_I32) {
|
||||||
+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
||||||
|
} else {
|
||||||
|
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
- }
|
||||||
|
+#ifdef GGML_CUDA_USE_CUB
|
||||||
|
+ const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
+ const size_t shared_mem = ncols_pad * sizeof(int);
|
||||||
|
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
||||||
|
+
|
||||||
|
+ if (shared_mem > max_shared_mem || ncols > 1024) {
|
||||||
|
+ ggml_cuda_pool & pool = ctx.pool();
|
||||||
|
+ argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
+ } else {
|
+ } else {
|
||||||
+ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
+ }
|
||||||
|
#else
|
||||||
|
- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
#endif
|
||||||
+ }
|
+ }
|
||||||
}
|
}
|
||||||
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
|
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
|
||||||
index e621cb98..597c0c8b 100644
|
index e621cb981..597c0c8b3 100644
|
||||||
--- a/ggml/src/ggml-cuda/cpy-utils.cuh
|
--- a/ggml/src/ggml-cuda/cpy-utils.cuh
|
||||||
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
|
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
|
||||||
@@ -215,3 +215,9 @@ template<typename src_t, typename dst_t>
|
@@ -215,3 +215,9 @@ template<typename src_t, typename dst_t>
|
||||||
|
|
@ -211,19 +234,18 @@ index e621cb98..597c0c8b 100644
|
||||||
+ *dst = *src;
|
+ *dst = *src;
|
||||||
+}
|
+}
|
||||||
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
||||||
index 746f4396..911220e9 100644
|
index 12d5bf776..a0e34030e 100644
|
||||||
--- a/ggml/src/ggml-cuda/cpy.cu
|
--- a/ggml/src/ggml-cuda/cpy.cu
|
||||||
+++ b/ggml/src/ggml-cuda/cpy.cu
|
+++ b/ggml/src/ggml-cuda/cpy.cu
|
||||||
@@ -277,6 +277,47 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
@@ -251,6 +251,43 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
+template <cpy_kernel_t cpy_1>
|
+template <cpy_kernel_t cpy_1>
|
||||||
+static __global__ void cpy_i32_i32(
|
+static __global__ void cpy_i32_i32(
|
||||||
+ const char *cx, char *cdst, const int ne,
|
+ const char *cx, char *cdst, const int ne,
|
||||||
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||||
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
|
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||||
+ cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
|
||||||
+
|
+
|
||||||
+ const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
+ const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
+
|
+
|
||||||
|
|
@ -243,39 +265,37 @@ index 746f4396..911220e9 100644
|
||||||
+ const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
|
+ const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
|
||||||
+ const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
|
+ const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
|
||||||
+
|
+
|
||||||
+ char * cdst_ptr = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst;
|
+ cpy_1(cx + x_offset, cdst + dst_offset);
|
||||||
+ cpy_1(cx + x_offset, cdst_ptr + dst_offset);
|
|
||||||
+}
|
+}
|
||||||
+
|
+
|
||||||
+
|
|
||||||
+static void ggml_cpy_i32_i32_cuda(
|
+static void ggml_cpy_i32_i32_cuda(
|
||||||
+ const char * cx, char * cdst, const int ne,
|
+ const char * cx, char * cdst, const int ne,
|
||||||
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
|
||||||
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
|
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||||
+ cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
|
|
||||||
+
|
+
|
||||||
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||||
+ cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
+ cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||||
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream, cdst_indirect, graph_cpynode_index);
|
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream);
|
||||||
+}
|
+}
|
||||||
+
|
+
|
||||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
|
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||||
const int64_t ne = ggml_nelements(src0);
|
const int64_t ne = ggml_nelements(src0);
|
||||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||||
@@ -372,6 +413,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
@@ -332,6 +369,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
||||||
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
+ // TODO consider converting to template
|
||||||
|
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
index 74a9aa99..375a0c7f 100644
|
index 2c2f01415..50b8071de 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
@@ -4346,8 +4346,72 @@ kernel void kernel_argsort_f32_i32(
|
@@ -4467,8 +4467,72 @@ kernel void kernel_argsort_f32_i32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ Subject: [PATCH] graph memory reporting on failure
|
||||||
4 files changed, 40 insertions(+), 3 deletions(-)
|
4 files changed, 40 insertions(+), 3 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h
|
diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h
|
||||||
index 2cb150fd..7ab3f019 100644
|
index 2cb150fd2..7ab3f0192 100644
|
||||||
--- a/ggml/include/ggml-alloc.h
|
--- a/ggml/include/ggml-alloc.h
|
||||||
+++ b/ggml/include/ggml-alloc.h
|
+++ b/ggml/include/ggml-alloc.h
|
||||||
@@ -65,6 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n(
|
@@ -65,6 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n(
|
||||||
|
|
@ -23,7 +23,7 @@ index 2cb150fd..7ab3f019 100644
|
||||||
// Utils
|
// Utils
|
||||||
// Create a buffer and allocate all the tensors in a ggml_context
|
// Create a buffer and allocate all the tensors in a ggml_context
|
||||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||||
index f1b74078..c54ff98b 100644
|
index f1b740785..c54ff98bf 100644
|
||||||
--- a/ggml/include/ggml-backend.h
|
--- a/ggml/include/ggml-backend.h
|
||||||
+++ b/ggml/include/ggml-backend.h
|
+++ b/ggml/include/ggml-backend.h
|
||||||
@@ -318,6 +318,7 @@ extern "C" {
|
@@ -318,6 +318,7 @@ extern "C" {
|
||||||
|
|
@ -35,7 +35,7 @@ index f1b74078..c54ff98b 100644
|
||||||
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
||||||
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
||||||
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
|
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
|
||||||
index 929bc448..eee9d3b1 100644
|
index c830c0965..363853873 100644
|
||||||
--- a/ggml/src/ggml-alloc.c
|
--- a/ggml/src/ggml-alloc.c
|
||||||
+++ b/ggml/src/ggml-alloc.c
|
+++ b/ggml/src/ggml-alloc.c
|
||||||
@@ -486,6 +486,7 @@ struct node_alloc {
|
@@ -486,6 +486,7 @@ struct node_alloc {
|
||||||
|
|
@ -64,7 +64,7 @@ index 929bc448..eee9d3b1 100644
|
||||||
free(galloc->buffers);
|
free(galloc->buffers);
|
||||||
free(galloc->buf_tallocs);
|
free(galloc->buf_tallocs);
|
||||||
free(galloc->node_allocs);
|
free(galloc->node_allocs);
|
||||||
@@ -869,6 +874,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
@@ -891,6 +896,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -73,7 +73,7 @@ index 929bc448..eee9d3b1 100644
|
||||||
// reallocate buffers if needed
|
// reallocate buffers if needed
|
||||||
for (int i = 0; i < galloc->n_buffers; i++) {
|
for (int i = 0; i < galloc->n_buffers; i++) {
|
||||||
// if the buffer type is used multiple times, we reuse the same buffer
|
// if the buffer type is used multiple times, we reuse the same buffer
|
||||||
@@ -898,14 +905,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
@@ -920,14 +927,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||||
|
|
||||||
ggml_vbuffer_free(galloc->buffers[i]);
|
ggml_vbuffer_free(galloc->buffers[i]);
|
||||||
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||||
|
|
@ -96,7 +96,7 @@ index 929bc448..eee9d3b1 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
|
bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
|
||||||
@@ -1060,6 +1072,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
|
@@ -1082,6 +1094,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
|
||||||
return ggml_vbuffer_size(galloc->buffers[buffer_id]);
|
return ggml_vbuffer_size(galloc->buffers[buffer_id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -120,7 +120,7 @@ index 929bc448..eee9d3b1 100644
|
||||||
|
|
||||||
static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
|
static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
|
||||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||||
index 8ba86f82..cb2b9956 100644
|
index 8ba86f824..cb2b99562 100644
|
||||||
--- a/ggml/src/ggml-backend.cpp
|
--- a/ggml/src/ggml-backend.cpp
|
||||||
+++ b/ggml/src/ggml-backend.cpp
|
+++ b/ggml/src/ggml-backend.cpp
|
||||||
@@ -1809,6 +1809,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
|
@@ -1809,6 +1809,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
|
||||||
3 files changed, 63 insertions(+), 6 deletions(-)
|
3 files changed, 63 insertions(+), 6 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||||
index c54ff98b..229bf387 100644
|
index c54ff98bf..229bf387b 100644
|
||||||
--- a/ggml/include/ggml-backend.h
|
--- a/ggml/include/ggml-backend.h
|
||||||
+++ b/ggml/include/ggml-backend.h
|
+++ b/ggml/include/ggml-backend.h
|
||||||
@@ -158,6 +158,7 @@ extern "C" {
|
@@ -158,6 +158,7 @@ extern "C" {
|
||||||
|
|
@ -24,7 +24,7 @@ index c54ff98b..229bf387 100644
|
||||||
size_t memory_total;
|
size_t memory_total;
|
||||||
// device type
|
// device type
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index c0b1e4c1..5b852f69 100644
|
index aefc6935e..cc201afff 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -183,6 +183,51 @@ static int ggml_cuda_parse_id(char devName[]) {
|
@@ -183,6 +183,51 @@ static int ggml_cuda_parse_id(char devName[]) {
|
||||||
|
|
@ -110,7 +110,7 @@ index c0b1e4c1..5b852f69 100644
|
||||||
std::string device_name(prop.name);
|
std::string device_name(prop.name);
|
||||||
if (device_name == "NVIDIA GeForce MX450") {
|
if (device_name == "NVIDIA GeForce MX450") {
|
||||||
turing_devices_without_mma.push_back({ id, device_name });
|
turing_devices_without_mma.push_back({ id, device_name });
|
||||||
@@ -3276,6 +3323,7 @@ struct ggml_backend_cuda_device_context {
|
@@ -3268,6 +3315,7 @@ struct ggml_backend_cuda_device_context {
|
||||||
std::string name;
|
std::string name;
|
||||||
std::string description;
|
std::string description;
|
||||||
std::string pci_bus_id;
|
std::string pci_bus_id;
|
||||||
|
|
@ -118,7 +118,7 @@ index c0b1e4c1..5b852f69 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
||||||
@@ -3288,6 +3336,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
@@ -3280,6 +3328,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||||
return ctx->description.c_str();
|
return ctx->description.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -130,7 +130,7 @@ index c0b1e4c1..5b852f69 100644
|
||||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||||
ggml_cuda_set_device(ctx->device);
|
ggml_cuda_set_device(ctx->device);
|
||||||
@@ -3304,6 +3357,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
@@ -3296,6 +3349,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||||
|
|
||||||
props->name = ggml_backend_cuda_device_get_name(dev);
|
props->name = ggml_backend_cuda_device_get_name(dev);
|
||||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||||
|
|
@ -138,7 +138,7 @@ index c0b1e4c1..5b852f69 100644
|
||||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||||
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
|
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
|
||||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||||
@@ -3873,6 +3927,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
@@ -3869,6 +3923,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||||
cudaDeviceProp prop;
|
cudaDeviceProp prop;
|
||||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||||
dev_ctx->description = prop.name;
|
dev_ctx->description = prop.name;
|
||||||
|
|
@ -147,7 +147,7 @@ index c0b1e4c1..5b852f69 100644
|
||||||
char pci_bus_id[16] = {};
|
char pci_bus_id[16] = {};
|
||||||
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
|
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||||
index bf096227..f2ff9f32 100644
|
index bf0962274..f2ff9f322 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.cpp
|
--- a/ggml/src/ggml-metal/ggml-metal.cpp
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
|
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||||
@@ -538,6 +538,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
@@ -538,6 +538,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
|
||||||
2 files changed, 13 insertions(+)
|
2 files changed, 13 insertions(+)
|
||||||
|
|
||||||
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
|
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
|
||||||
index 4d487581..35a0d25e 100644
|
index 4d487581a..35a0d25ed 100644
|
||||||
--- a/tools/mtmd/mtmd.cpp
|
--- a/tools/mtmd/mtmd.cpp
|
||||||
+++ b/tools/mtmd/mtmd.cpp
|
+++ b/tools/mtmd/mtmd.cpp
|
||||||
@@ -79,6 +79,16 @@ enum mtmd_slice_tmpl {
|
@@ -79,6 +79,16 @@ enum mtmd_slice_tmpl {
|
||||||
|
|
@ -31,7 +31,7 @@ index 4d487581..35a0d25e 100644
|
||||||
return "<__media__>";
|
return "<__media__>";
|
||||||
}
|
}
|
||||||
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
|
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
|
||||||
index f4ea07d3..cf287224 100644
|
index f4ea07d3a..cf287224b 100644
|
||||||
--- a/tools/mtmd/mtmd.h
|
--- a/tools/mtmd/mtmd.h
|
||||||
+++ b/tools/mtmd/mtmd.h
|
+++ b/tools/mtmd/mtmd.h
|
||||||
@@ -75,6 +75,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk;
|
@@ -75,6 +75,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk;
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc
|
||||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
|
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
|
||||||
index 99509b0c..b13a491d 100644
|
index 4b2f8b7bd..046646282 100644
|
||||||
--- a/ggml/src/ggml-cpu/ggml-cpu.c
|
--- a/ggml/src/ggml-cpu/ggml-cpu.c
|
||||||
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
|
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
|
||||||
@@ -2437,7 +2437,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
|
@@ -2441,7 +2441,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
|
||||||
// Newer Windows 11 versions aggresively park (offline) CPU cores and often place
|
// Newer Windows 11 versions aggresively park (offline) CPU cores and often place
|
||||||
// all our threads onto the first 4 cores which results in terrible performance with
|
// all our threads onto the first 4 cores which results in terrible performance with
|
||||||
// n_threads > 4
|
// n_threads > 4
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ Only enable BF16 on supported MacOS versions (v14+)
|
||||||
1 file changed, 6 insertions(+), 1 deletion(-)
|
1 file changed, 6 insertions(+), 1 deletion(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m
|
diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m
|
||||||
index 052efb7a..b47dc787 100644
|
index 052efb7ac..b47dc7879 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal-context.m
|
--- a/ggml/src/ggml-metal/ggml-metal-context.m
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal-context.m
|
+++ b/ggml/src/ggml-metal/ggml-metal-context.m
|
||||||
@@ -125,7 +125,12 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
@@ -125,7 +125,12 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||||
|
|
|
||||||
|
|
@ -178,19 +178,19 @@ index 3191faaa4..32f14c811 100644
|
||||||
|
|
||||||
static const struct ggml_backend_i ggml_backend_cpu_i = {
|
static const struct ggml_backend_i ggml_backend_cpu_i = {
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index 5b852f690..c555cd30f 100644
|
index cc201afff..02d413467 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -2684,7 +2684,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
@@ -2693,7 +2693,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||||
|
|
||||||
#ifdef USE_CUDA_GRAPH
|
#ifdef USE_CUDA_GRAPH
|
||||||
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
|
||||||
- bool use_cuda_graph) {
|
- bool use_cuda_graph) {
|
||||||
+ int batch_size, bool use_cuda_graph) {
|
+ int batch_size, bool use_cuda_graph) {
|
||||||
|
|
||||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||||
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
|
|
||||||
@@ -2718,24 +2718,34 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
@@ -2726,24 +2726,34 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -240,8 +240,8 @@ index 5b852f690..c555cd30f 100644
|
||||||
+ }
|
+ }
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->op == GGML_OP_CPY) {
|
if (!use_cuda_graph) {
|
||||||
@@ -3132,7 +3142,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
@@ -3128,7 +3138,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -250,12 +250,12 @@ index 5b852f690..c555cd30f 100644
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||||
|
|
||||||
ggml_cuda_set_device(cuda_ctx->device);
|
ggml_cuda_set_device(cuda_ctx->device);
|
||||||
@@ -3170,7 +3180,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
@@ -3166,7 +3176,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||||
if (use_cuda_graph) {
|
if (use_cuda_graph) {
|
||||||
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
|
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||||
|
|
||||||
- use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
|
- use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
|
||||||
+ use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, batch_size, use_cuda_graph);
|
+ use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph);
|
||||||
|
|
||||||
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
||||||
if (use_cuda_graph && cuda_graph_update_required) {
|
if (use_cuda_graph && cuda_graph_update_required) {
|
||||||
|
|
@ -278,10 +278,10 @@ index f2ff9f322..05ff6a5a6 100644
|
||||||
|
|
||||||
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
index ed83236f4..bd3ece516 100644
|
index 216dc167c..3a6bbe564 100644
|
||||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
@@ -12015,7 +12015,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
|
@@ -12357,7 +12357,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
|
||||||
return num_adds;
|
return num_adds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -290,7 +290,7 @@ index ed83236f4..bd3ece516 100644
|
||||||
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
||||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||||
|
|
||||||
@@ -12211,6 +12211,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
@@ -12561,6 +12561,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
|
|
||||||
UNUSED(backend);
|
UNUSED(backend);
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ Subject: [PATCH] Disable ggml-blas on macos v13 and older
|
||||||
1 file changed, 5 insertions(+)
|
1 file changed, 5 insertions(+)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
|
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
|
||||||
index 5b888cdd..2a9ff7f6 100644
|
index 88d088952..6a38a51a2 100644
|
||||||
--- a/ggml/src/ggml-blas/ggml-blas.cpp
|
--- a/ggml/src/ggml-blas/ggml-blas.cpp
|
||||||
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
|
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
|
||||||
@@ -506,6 +506,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
|
@@ -507,6 +507,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_reg_t ggml_backend_blas_reg(void) {
|
ggml_backend_reg_t ggml_backend_blas_reg(void) {
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ Subject: [PATCH] fix mtmd-audio.cpp build on windows
|
||||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||||
|
|
||||||
diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp
|
diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp
|
||||||
index 4d053895..84bdc277 100644
|
index 4d053895c..84bdc2777 100644
|
||||||
--- a/tools/mtmd/mtmd-audio.cpp
|
--- a/tools/mtmd/mtmd-audio.cpp
|
||||||
+++ b/tools/mtmd/mtmd-audio.cpp
|
+++ b/tools/mtmd/mtmd-audio.cpp
|
||||||
@@ -1,6 +1,6 @@
|
@@ -1,6 +1,6 @@
|
||||||
|
|
|
||||||
|
|
@ -219,7 +219,7 @@ index 41eef3b5f..c81a2e48a 100644
|
||||||
|
|
||||||
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
|
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
|
||||||
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||||
index e0abde542..e98044bd8 100644
|
index 41ff89c4d..2931c15ca 100644
|
||||||
--- a/ggml/src/ggml-cuda/common.cuh
|
--- a/ggml/src/ggml-cuda/common.cuh
|
||||||
+++ b/ggml/src/ggml-cuda/common.cuh
|
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||||
@@ -35,6 +35,41 @@
|
@@ -35,6 +35,41 @@
|
||||||
|
|
@ -274,7 +274,7 @@ index e0abde542..e98044bd8 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@@ -999,11 +1037,11 @@ struct ggml_backend_cuda_context {
|
@@ -992,11 +1030,11 @@ struct ggml_backend_cuda_context {
|
||||||
// pool
|
// pool
|
||||||
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
|
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
|
||||||
|
|
||||||
|
|
@ -288,7 +288,7 @@ index e0abde542..e98044bd8 100644
|
||||||
}
|
}
|
||||||
return *pools[device];
|
return *pools[device];
|
||||||
}
|
}
|
||||||
@@ -1011,4 +1049,20 @@ struct ggml_backend_cuda_context {
|
@@ -1004,4 +1042,20 @@ struct ggml_backend_cuda_context {
|
||||||
ggml_cuda_pool & pool() {
|
ggml_cuda_pool & pool() {
|
||||||
return pool(device);
|
return pool(device);
|
||||||
}
|
}
|
||||||
|
|
@ -310,10 +310,10 @@ index e0abde542..e98044bd8 100644
|
||||||
+ }
|
+ }
|
||||||
};
|
};
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index c555cd30f..eb3db0f19 100644
|
index 02d413467..f79e5d65c 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -350,6 +350,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {
|
@@ -359,6 +359,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {
|
||||||
|
|
||||||
// #define DEBUG_CUDA_MALLOC
|
// #define DEBUG_CUDA_MALLOC
|
||||||
|
|
||||||
|
|
@ -322,7 +322,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
// buffer pool for cuda (legacy)
|
// buffer pool for cuda (legacy)
|
||||||
struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||||
static const int MAX_BUFFERS = 256;
|
static const int MAX_BUFFERS = 256;
|
||||||
@@ -362,9 +364,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
@@ -371,9 +373,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||||
|
|
||||||
ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
|
ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
|
||||||
size_t pool_size = 0;
|
size_t pool_size = 0;
|
||||||
|
|
@ -337,7 +337,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
~ggml_cuda_pool_leg() {
|
~ggml_cuda_pool_leg() {
|
||||||
@@ -372,7 +377,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
@@ -381,7 +386,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||||
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
||||||
ggml_cuda_buffer & b = buffer_pool[i];
|
ggml_cuda_buffer & b = buffer_pool[i];
|
||||||
if (b.ptr != nullptr) {
|
if (b.ptr != nullptr) {
|
||||||
|
|
@ -348,7 +348,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
pool_size -= b.size;
|
pool_size -= b.size;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -420,8 +427,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
@@ -429,8 +436,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||||
void * ptr;
|
void * ptr;
|
||||||
size_t look_ahead_size = (size_t) (1.05 * size);
|
size_t look_ahead_size = (size_t) (1.05 * size);
|
||||||
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
||||||
|
|
@ -366,7 +366,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
*actual_size = look_ahead_size;
|
*actual_size = look_ahead_size;
|
||||||
pool_size += look_ahead_size;
|
pool_size += look_ahead_size;
|
||||||
#ifdef DEBUG_CUDA_MALLOC
|
#ifdef DEBUG_CUDA_MALLOC
|
||||||
@@ -441,10 +455,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
@@ -450,10 +464,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
||||||
|
|
@ -389,7 +389,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
// pool with virtual memory
|
// pool with virtual memory
|
||||||
@@ -456,18 +480,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
@@ -465,18 +489,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||||
CUdeviceptr pool_addr = 0;
|
CUdeviceptr pool_addr = 0;
|
||||||
size_t pool_used = 0;
|
size_t pool_used = 0;
|
||||||
size_t pool_size = 0;
|
size_t pool_size = 0;
|
||||||
|
|
@ -417,7 +417,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
#if defined(GGML_USE_HIP)
|
#if defined(GGML_USE_HIP)
|
||||||
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
|
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
|
||||||
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
|
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
|
||||||
@@ -494,35 +524,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
@@ -503,35 +533,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||||
|
|
||||||
GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
|
GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
|
||||||
|
|
||||||
|
|
@ -493,7 +493,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
|
|
||||||
// add to the pool
|
// add to the pool
|
||||||
pool_size += reserve_size;
|
pool_size += reserve_size;
|
||||||
@@ -555,16 +599,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
@@ -564,16 +608,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||||
// all deallocations must be in reverse order of the allocations
|
// all deallocations must be in reverse order of the allocations
|
||||||
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
|
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
|
||||||
}
|
}
|
||||||
|
|
@ -521,7 +521,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
|
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
|
||||||
@@ -748,11 +800,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
|
@@ -757,11 +809,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||||
|
|
@ -543,7 +543,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||||
size_t size = ggml_nbytes(tensor);
|
size_t size = ggml_nbytes(tensor);
|
||||||
int64_t ne0 = tensor->ne[0];
|
int64_t ne0 = tensor->ne[0];
|
||||||
@@ -776,6 +837,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface
|
@@ -785,6 +846,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface
|
||||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||||
/* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
|
/* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
|
||||||
/* .is_host = */ NULL,
|
/* .is_host = */ NULL,
|
||||||
|
|
@ -551,7 +551,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
|
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
|
||||||
@@ -3003,6 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
@@ -2986,6 +3048,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||||
|
|
||||||
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||||
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
||||||
|
|
@ -559,7 +559,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
// flag used to determine whether it is an integrated_gpu
|
// flag used to determine whether it is an integrated_gpu
|
||||||
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||||
|
|
||||||
@@ -3018,6 +3081,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
@@ -3001,6 +3064,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -571,7 +571,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||||
if (!disable_fusion) {
|
if (!disable_fusion) {
|
||||||
|
|
||||||
@@ -3144,6 +3212,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
@@ -3140,6 +3208,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||||
|
|
||||||
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
|
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||||
|
|
@ -579,7 +579,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
|
|
||||||
ggml_cuda_set_device(cuda_ctx->device);
|
ggml_cuda_set_device(cuda_ctx->device);
|
||||||
|
|
||||||
@@ -3223,6 +3292,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
@@ -3215,6 +3284,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -651,7 +651,7 @@ index c555cd30f..eb3db0f19 100644
|
||||||
static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
|
static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||||
|
|
||||||
@@ -3263,6 +3397,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
@@ -3255,6 +3389,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||||
/* .event_record = */ ggml_backend_cuda_event_record,
|
/* .event_record = */ ggml_backend_cuda_event_record,
|
||||||
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
||||||
/* .graph_optimize = */ NULL,
|
/* .graph_optimize = */ NULL,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ Subject: [PATCH] decode: disable output_all
|
||||||
1 file changed, 1 insertion(+), 2 deletions(-)
|
1 file changed, 1 insertion(+), 2 deletions(-)
|
||||||
|
|
||||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
||||||
index e7526e7d..53a5e3a9 100644
|
index bd348bcad..8b4a89d38 100644
|
||||||
--- a/src/llama-context.cpp
|
--- a/src/llama-context.cpp
|
||||||
+++ b/src/llama-context.cpp
|
+++ b/src/llama-context.cpp
|
||||||
@@ -974,8 +974,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
@@ -974,8 +974,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ unused then it can be reset to free these data structures.
|
||||||
6 files changed, 32 insertions(+), 2 deletions(-)
|
6 files changed, 32 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||||
index 1ff53ed03..ba181d09d 100644
|
index b3b5b356a..69223c488 100644
|
||||||
--- a/ggml/include/ggml-backend.h
|
--- a/ggml/include/ggml-backend.h
|
||||||
+++ b/ggml/include/ggml-backend.h
|
+++ b/ggml/include/ggml-backend.h
|
||||||
@@ -178,6 +178,7 @@ extern "C" {
|
@@ -178,6 +178,7 @@ extern "C" {
|
||||||
|
|
@ -28,7 +28,7 @@ index 1ff53ed03..ba181d09d 100644
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||||
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
|
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
|
||||||
index 3c3f22fc0..43c91d9f2 100644
|
index 7bdf9d81f..21b35ac5c 100644
|
||||||
--- a/ggml/src/ggml-backend-impl.h
|
--- a/ggml/src/ggml-backend-impl.h
|
||||||
+++ b/ggml/src/ggml-backend-impl.h
|
+++ b/ggml/src/ggml-backend-impl.h
|
||||||
@@ -195,6 +195,10 @@ extern "C" {
|
@@ -195,6 +195,10 @@ extern "C" {
|
||||||
|
|
@ -43,7 +43,7 @@ index 3c3f22fc0..43c91d9f2 100644
|
||||||
|
|
||||||
struct ggml_backend_device {
|
struct ggml_backend_device {
|
||||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||||
index 6ef5eeafa..0b757af59 100644
|
index c81a2e48a..9b0a9b91f 100644
|
||||||
--- a/ggml/src/ggml-backend.cpp
|
--- a/ggml/src/ggml-backend.cpp
|
||||||
+++ b/ggml/src/ggml-backend.cpp
|
+++ b/ggml/src/ggml-backend.cpp
|
||||||
@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
||||||
|
|
@ -62,7 +62,7 @@ index 6ef5eeafa..0b757af59 100644
|
||||||
GGML_ASSERT(device);
|
GGML_ASSERT(device);
|
||||||
return device->iface.get_buffer_type(device);
|
return device->iface.get_buffer_type(device);
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index 811462c79..87c6c34a4 100644
|
index f79e5d65c..c9333689f 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -107,6 +107,11 @@ int ggml_cuda_get_device() {
|
@@ -107,6 +107,11 @@ int ggml_cuda_get_device() {
|
||||||
|
|
@ -77,7 +77,7 @@ index 811462c79..87c6c34a4 100644
|
||||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||||
ggml_cuda_set_device(device);
|
ggml_cuda_set_device(device);
|
||||||
cudaError_t err;
|
cudaError_t err;
|
||||||
@@ -3515,7 +3520,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
@@ -3499,7 +3504,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||||
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
|
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
|
||||||
|
|
@ -89,7 +89,7 @@ index 811462c79..87c6c34a4 100644
|
||||||
|
|
||||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||||
@@ -3948,6 +3956,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
@@ -3936,6 +3944,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
||||||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -101,7 +101,7 @@ index 811462c79..87c6c34a4 100644
|
||||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||||
@@ -3964,6 +3977,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
@@ -3952,6 +3965,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||||
|
|
@ -122,10 +122,10 @@ index 890c10364..1f06be80e 100644
|
||||||
#define cudaError_t hipError_t
|
#define cudaError_t hipError_t
|
||||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
index fe5a7a835..d821a96a0 100644
|
index ab2e9868a..74c49e651 100644
|
||||||
--- a/src/llama.cpp
|
--- a/src/llama.cpp
|
||||||
+++ b/src/llama.cpp
|
+++ b/src/llama.cpp
|
||||||
@@ -267,10 +267,12 @@ static struct llama_model * llama_model_load_from_file_impl(
|
@@ -270,10 +270,12 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||||
for (auto * dev : model->devices) {
|
for (auto * dev : model->devices) {
|
||||||
ggml_backend_dev_props props;
|
ggml_backend_dev_props props;
|
||||||
ggml_backend_dev_get_props(dev, &props);
|
ggml_backend_dev_get_props(dev, &props);
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ Subject: [PATCH] harden uncaught exception registration
|
||||||
1 file changed, 6 insertions(+), 2 deletions(-)
|
1 file changed, 6 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
|
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
|
||||||
index 0d388d45..f5bcb446 100644
|
index 0d388d455..f5bcb446d 100644
|
||||||
--- a/ggml/src/ggml.cpp
|
--- a/ggml/src/ggml.cpp
|
||||||
+++ b/ggml/src/ggml.cpp
|
+++ b/ggml/src/ggml.cpp
|
||||||
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
|
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ index 69223c488..6510e0cba 100644
|
||||||
|
|
||||||
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
|
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
|
||||||
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
||||||
index 0609c6503..aefe43bdd 100644
|
index f9a6587f1..03f359ae9 100644
|
||||||
--- a/ggml/src/CMakeLists.txt
|
--- a/ggml/src/CMakeLists.txt
|
||||||
+++ b/ggml/src/CMakeLists.txt
|
+++ b/ggml/src/CMakeLists.txt
|
||||||
@@ -209,6 +209,8 @@ add_library(ggml-base
|
@@ -209,6 +209,8 @@ add_library(ggml-base
|
||||||
|
|
@ -58,7 +58,7 @@ index 0609c6503..aefe43bdd 100644
|
||||||
|
|
||||||
target_include_directories(ggml-base PRIVATE .)
|
target_include_directories(ggml-base PRIVATE .)
|
||||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
index 5787e8cd5..d232bf828 100644
|
index c9333689f..41b00af83 100644
|
||||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||||
|
|
@ -90,7 +90,7 @@ index 5787e8cd5..d232bf828 100644
|
||||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n",
|
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n",
|
||||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
|
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
|
||||||
ggml_cuda_parse_uuid(prop, id).c_str());
|
ggml_cuda_parse_uuid(prop, id).c_str());
|
||||||
@@ -3476,6 +3491,11 @@ struct ggml_backend_cuda_device_context {
|
@@ -3468,6 +3483,11 @@ struct ggml_backend_cuda_device_context {
|
||||||
std::string description;
|
std::string description;
|
||||||
std::string pci_bus_id;
|
std::string pci_bus_id;
|
||||||
std::string id;
|
std::string id;
|
||||||
|
|
@ -102,7 +102,7 @@ index 5787e8cd5..d232bf828 100644
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
||||||
@@ -3496,6 +3516,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
|
@@ -3488,6 +3508,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
|
||||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||||
ggml_cuda_set_device(ctx->device);
|
ggml_cuda_set_device(ctx->device);
|
||||||
|
|
@ -131,7 +131,7 @@ index 5787e8cd5..d232bf828 100644
|
||||||
CUDA_CHECK(cudaMemGetInfo(free, total));
|
CUDA_CHECK(cudaMemGetInfo(free, total));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3504,6 +3546,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
@@ -3496,6 +3538,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
||||||
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -139,7 +139,7 @@ index 5787e8cd5..d232bf828 100644
|
||||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||||
|
|
||||||
@@ -3517,6 +3560,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
@@ -3509,6 +3552,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||||
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||||
props->memory_total = props->memory_free = 0;
|
props->memory_total = props->memory_free = 0;
|
||||||
|
|
||||||
|
|
@ -159,7 +159,7 @@ index 5787e8cd5..d232bf828 100644
|
||||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||||
bool events = false;
|
bool events = false;
|
||||||
@@ -4079,6 +4135,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
@@ -4075,6 +4131,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
if (!initialized) {
|
if (!initialized) {
|
||||||
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
|
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
|
||||||
|
|
@ -167,7 +167,7 @@ index 5787e8cd5..d232bf828 100644
|
||||||
|
|
||||||
for (int i = 0; i < ggml_cuda_info().device_count; i++) {
|
for (int i = 0; i < ggml_cuda_info().device_count; i++) {
|
||||||
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
|
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
|
||||||
@@ -4094,6 +4151,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
@@ -4090,6 +4147,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||||
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
||||||
dev_ctx->pci_bus_id = pci_bus_id;
|
dev_ctx->pci_bus_id = pci_bus_id;
|
||||||
|
|
||||||
|
|
@ -204,11 +204,11 @@ index 1f06be80e..2f9ef2dc0 100644
|
||||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||||
index d0fb3bcca..b63edd0c1 100644
|
index e9201cdc6..44ae76d66 100644
|
||||||
--- a/ggml/src/ggml-impl.h
|
--- a/ggml/src/ggml-impl.h
|
||||||
+++ b/ggml/src/ggml-impl.h
|
+++ b/ggml/src/ggml-impl.h
|
||||||
@@ -638,6 +638,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||||
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
+// Management libraries for fetching more accurate free VRAM data
|
+// Management libraries for fetching more accurate free VRAM data
|
||||||
|
|
@ -243,10 +243,10 @@ index 05ff6a5a6..032dee76d 100644
|
||||||
/* .async = */ true,
|
/* .async = */ true,
|
||||||
/* .host_buffer = */ false,
|
/* .host_buffer = */ false,
|
||||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
index bd3ece516..7cfb14a54 100644
|
index 3a6bbe564..d2c278a35 100644
|
||||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
@@ -231,6 +231,7 @@ class vk_memory_logger;
|
@@ -229,6 +229,7 @@ class vk_memory_logger;
|
||||||
#endif
|
#endif
|
||||||
class vk_perf_logger;
|
class vk_perf_logger;
|
||||||
static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
||||||
|
|
@ -254,7 +254,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
|
|
||||||
static constexpr uint32_t mul_mat_vec_max_cols = 8;
|
static constexpr uint32_t mul_mat_vec_max_cols = 8;
|
||||||
static constexpr uint32_t p021_max_gqa_ratio = 8;
|
static constexpr uint32_t p021_max_gqa_ratio = 8;
|
||||||
@@ -11585,6 +11586,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
|
@@ -11813,6 +11814,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
|
||||||
snprintf(description, description_size, "%s", props.deviceName.data());
|
snprintf(description, description_size, "%s", props.deviceName.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -284,7 +284,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
// backend interface
|
// backend interface
|
||||||
|
|
||||||
#define UNUSED GGML_UNUSED
|
#define UNUSED GGML_UNUSED
|
||||||
@@ -12392,31 +12416,102 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
|
@@ -12761,31 +12785,102 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
|
||||||
ggml_vk_get_device_description(dev_idx, description, description_size);
|
ggml_vk_get_device_description(dev_idx, description, description_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -404,7 +404,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -12449,8 +12544,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
|
@@ -12818,8 +12913,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -419,7 +419,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
vk::PhysicalDeviceProperties2 props = {};
|
vk::PhysicalDeviceProperties2 props = {};
|
||||||
@@ -12467,19 +12567,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
|
@@ -12836,19 +12936,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
|
||||||
|
|
||||||
char pci_bus_id[16] = {};
|
char pci_bus_id[16] = {};
|
||||||
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
|
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
|
||||||
|
|
@ -453,7 +453,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
|
|
||||||
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
|
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
|
||||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
@@ -12491,9 +12596,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
|
@@ -12860,9 +12965,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
|
||||||
return ctx->description.c_str();
|
return ctx->description.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -469,7 +469,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
|
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||||
@@ -12517,8 +12627,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
|
@@ -12886,8 +12996,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
|
||||||
|
|
||||||
props->name = ggml_backend_vk_device_get_name(dev);
|
props->name = ggml_backend_vk_device_get_name(dev);
|
||||||
props->description = ggml_backend_vk_device_get_description(dev);
|
props->description = ggml_backend_vk_device_get_description(dev);
|
||||||
|
|
@ -480,7 +480,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||||
props->caps = {
|
props->caps = {
|
||||||
/* .async = */ false,
|
/* .async = */ false,
|
||||||
@@ -12526,6 +12637,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
|
@@ -12895,6 +13006,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
|
||||||
/* .buffer_from_host_ptr = */ false,
|
/* .buffer_from_host_ptr = */ false,
|
||||||
/* .events = */ false,
|
/* .events = */ false,
|
||||||
};
|
};
|
||||||
|
|
@ -494,7 +494,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
|
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||||
@@ -12954,6 +13072,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
@@ -13365,6 +13483,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
||||||
static std::mutex mutex;
|
static std::mutex mutex;
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
if (!initialized) {
|
if (!initialized) {
|
||||||
|
|
@ -503,7 +503,7 @@ index bd3ece516..7cfb14a54 100644
|
||||||
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
|
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
|
||||||
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
|
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
|
||||||
char desc[256];
|
char desc[256];
|
||||||
@@ -12962,12 +13082,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
@@ -13373,12 +13493,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
||||||
ctx->name = GGML_VK_NAME + std::to_string(i);
|
ctx->name = GGML_VK_NAME + std::to_string(i);
|
||||||
ctx->description = desc;
|
ctx->description = desc;
|
||||||
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
|
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||||
|
|
|
||||||
|
|
@@ -1,49 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Julius Tischbein <ju.tischbein@gmail.com>
Date: Wed, 15 Oct 2025 13:54:15 +0200
Subject: [PATCH] CUDA: Changing the CUDA scheduling strategy to spin (#16585)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* CUDA set scheduling strategy to spinning for cc121

* Using prop.major and prop.minor, include HIP and MUSA

* Exclude HIP and MUSA

* Remove trailing whitespace

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Remove empty line

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---
ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++
1 file changed, 9 insertions(+)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index b075a18be..d62f412d6 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -340,6 +340,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
+
+ // Temporary performance fix:
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
+ // TODO: Check for future drivers the default scheduling strategy and
+ // remove this call again when cudaDeviceScheduleSpin is default.
+ if (prop.major == 12 && prop.minor == 1) {
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
+ }
+
#endif // defined(GGML_USE_HIP)
}
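Note: this hunk deletes the vendored patch, presumably because the equivalent fix ships with the updated ggml. The removed change boils down to one CUDA runtime call per affected device. Below is a minimal standalone sketch of that call — not ollama/ggml code; it assumes a CUDA 12.x toolkit and skips the HIP/MUSA exclusions mentioned in the commit message.

```cpp
// Hedged sketch: enable spin scheduling on compute-capability 12.1 devices,
// mirroring what the removed patch did inside ggml_cuda_init().
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int count = 0;
    if (cudaGetDeviceCount(&count) != cudaSuccess) return 1;
    for (int id = 0; id < count; ++id) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, id);
        if (prop.major == 12 && prop.minor == 1) {          // the iGPUs the patch targets
            cudaSetDevice(id);
            cudaSetDeviceFlags(cudaDeviceScheduleSpin);      // spin instead of yielding/blocking in sync calls
            std::printf("device %d: spin scheduling enabled\n", id);
        }
    }
    return 0;
}
```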
@@ -8,10 +8,10 @@ Subject: [PATCH] report LoadLibrary failures
1 file changed, 12 insertions(+)

diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index f794d9cfa..3a855ab2e 100644
index a55d9b280..ec6f7f1e9 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -118,6 +118,18 @@ static dl_handle * dl_load_library(const fs::path & path) {
@@ -122,6 +122,18 @@ static dl_handle * dl_load_library(const fs::path & path) {
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);

HMODULE handle = LoadLibraryW(path.wstring().c_str());
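The "report LoadLibrary failures" patch adds roughly a dozen lines after the LoadLibraryW call to say why a backend DLL refused to load; the exact logging statement is not visible in this view. A hypothetical illustration of the idea, using only plain Win32 calls (names here are not ggml's):

```cpp
// Hedged sketch, not the actual patch body: surface GetLastError() as readable text
// when LoadLibraryW fails, instead of silently skipping the backend.
#include <windows.h>
#include <cstdio>

static HMODULE load_or_report(const wchar_t * path) {
    HMODULE handle = LoadLibraryW(path);
    if (handle == nullptr) {
        DWORD err = GetLastError();
        char msg[256] = {0};
        FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                       nullptr, err, 0, msg, sizeof(msg), nullptr);
        std::fprintf(stderr, "failed to load %ls: error %lu: %s\n", path, err, msg);
    }
    return handle;
}
```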
@ -13,7 +13,7 @@ interleaved version used for qwen3vl
|
||||||
4 files changed, 11 insertions(+), 30 deletions(-)
|
4 files changed, 11 insertions(+), 30 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||||
index 31478dd8e..4d1ed207e 100644
|
index 902fdad69..70955347d 100644
|
||||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||||
@@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init(
|
@@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init(
|
||||||
|
|
@ -62,10 +62,10 @@ index d058504cd..287fe9d2c 100644
|
||||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
index 375a0c7fd..9866c96b4 100644
|
index 50b8071de..65a3183c8 100644
|
||||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||||
@@ -3858,15 +3858,11 @@ kernel void kernel_rope_multi(
|
@@ -3888,15 +3888,11 @@ kernel void kernel_rope_multi(
|
||||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||||
const int sector = ic % sect_dims;
|
const int sector = ic % sect_dims;
|
||||||
|
|
||||||
|
|
@ -12,7 +12,7 @@ Subject: [PATCH] Add memory detection using DXGI + PDH
|
||||||
create mode 100644 ggml/src/mem_dxgi_pdh.cpp
|
create mode 100644 ggml/src/mem_dxgi_pdh.cpp
|
||||||
|
|
||||||
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
||||||
index aefe43bdd..21fe4640c 100644
|
index 03f359ae9..4b3e5efb5 100644
|
||||||
--- a/ggml/src/CMakeLists.txt
|
--- a/ggml/src/CMakeLists.txt
|
||||||
+++ b/ggml/src/CMakeLists.txt
|
+++ b/ggml/src/CMakeLists.txt
|
||||||
@@ -211,6 +211,7 @@ add_library(ggml-base
|
@@ -211,6 +211,7 @@ add_library(ggml-base
|
||||||
|
|
@ -24,10 +24,10 @@ index aefe43bdd..21fe4640c 100644
|
||||||
|
|
||||||
target_include_directories(ggml-base PRIVATE .)
|
target_include_directories(ggml-base PRIVATE .)
|
||||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||||
index b63edd0c1..81cad8cf3 100644
|
index 44ae76d66..639d551a2 100644
|
||||||
--- a/ggml/src/ggml-impl.h
|
--- a/ggml/src/ggml-impl.h
|
||||||
+++ b/ggml/src/ggml-impl.h
|
+++ b/ggml/src/ggml-impl.h
|
||||||
@@ -645,6 +645,9 @@ GGML_API void ggml_nvml_release();
|
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
|
||||||
GGML_API int ggml_hip_mgmt_init();
|
GGML_API int ggml_hip_mgmt_init();
|
||||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
||||||
GGML_API void ggml_hip_mgmt_release();
|
GGML_API void ggml_hip_mgmt_release();
|
||||||
|
|
@ -38,7 +38,7 @@ index b63edd0c1..81cad8cf3 100644
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
index 7cfb14a54..a1c46d0b3 100644
|
index d2c278a35..221e29509 100644
|
||||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||||
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||||
|
|
@ -49,7 +49,7 @@ index 7cfb14a54..a1c46d0b3 100644
|
||||||
|
|
||||||
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
||||||
VkStructureType sType;
|
VkStructureType sType;
|
||||||
@@ -12433,6 +12434,7 @@ struct ggml_backend_vk_device_context {
|
@@ -12802,6 +12803,7 @@ struct ggml_backend_vk_device_context {
|
||||||
std::string pci_id;
|
std::string pci_id;
|
||||||
std::string id;
|
std::string id;
|
||||||
std::string uuid;
|
std::string uuid;
|
||||||
|
|
@ -57,7 +57,7 @@ index 7cfb14a54..a1c46d0b3 100644
|
||||||
int major;
|
int major;
|
||||||
int minor;
|
int minor;
|
||||||
int driver_major;
|
int driver_major;
|
||||||
@@ -12448,8 +12450,22 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
|
@@ -12817,8 +12819,22 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
|
||||||
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
|
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
|
||||||
vk::PhysicalDeviceProperties2 props2;
|
vk::PhysicalDeviceProperties2 props2;
|
||||||
vkdev.getProperties2(&props2);
|
vkdev.getProperties2(&props2);
|
||||||
|
|
@ -81,7 +81,7 @@ index 7cfb14a54..a1c46d0b3 100644
|
||||||
{
|
{
|
||||||
// Use vendor specific management libraries for best VRAM reporting if available
|
// Use vendor specific management libraries for best VRAM reporting if available
|
||||||
switch (props2.properties.vendorID) {
|
switch (props2.properties.vendorID) {
|
||||||
@@ -12477,8 +12493,8 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
|
@@ -12846,8 +12862,8 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -91,7 +91,7 @@ index 7cfb14a54..a1c46d0b3 100644
|
||||||
*total = 0;
|
*total = 0;
|
||||||
*free = 0;
|
*free = 0;
|
||||||
vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props;
|
vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props;
|
||||||
@@ -13089,7 +13105,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
@@ -13500,7 +13516,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
||||||
/* .reg = */ reg,
|
/* .reg = */ reg,
|
||||||
/* .context = */ ctx,
|
/* .context = */ ctx,
|
||||||
});
|
});
|
||||||
|
|
@ -99,7 +99,7 @@ index 7cfb14a54..a1c46d0b3 100644
|
||||||
// Gather additional information about the device
|
// Gather additional information about the device
|
||||||
int dev_idx = vk_instance.device_indices[i];
|
int dev_idx = vk_instance.device_indices[i];
|
||||||
vk::PhysicalDeviceProperties props1;
|
vk::PhysicalDeviceProperties props1;
|
||||||
@@ -13112,6 +13127,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
@@ -13523,6 +13538,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx->uuid = oss.str();
|
ctx->uuid = oss.str();
|
||||||
|
|
|
||||||
|
|
@@ -0,0 +1,19 @@
#pragma once

#include "ggml.h"
#include "ggml-backend.h"

#ifdef __cplusplus
extern "C" {
#endif

// backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);

GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);

GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);

#ifdef __cplusplus
}
#endif
@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
|
||||||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||||
size_t n_threads, size_t n_devices,
|
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
|
||||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
||||||
|
|
|
||||||
|
|
@@ -577,6 +577,10 @@ extern "C" {
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_XIELU,
GGML_UNARY_OP_FLOOR,
GGML_UNARY_OP_CEIL,
GGML_UNARY_OP_ROUND,
GGML_UNARY_OP_TRUNC,

GGML_UNARY_OP_COUNT,
};
@@ -1151,6 +1155,46 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_floor(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_floor_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_ceil(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_ceil_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_round(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_round_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

/**
* Truncates the fractional part of each element in the tensor (towards zero).
* For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
* Similar to std::trunc in C/C++.
*/
GGML_API struct ggml_tensor * ggml_trunc(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_trunc_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);


// xIELU activation function
// x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
// where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
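The new floor/ceil/round/trunc operators are ordinary element-wise unary ops. A hedged usage sketch on the CPU backend follows — not part of the patch; it assumes ggml_graph_compute_with_ctx() is exported from ggml-cpu.h, as in current trees (older trees declared it in ggml.h).

```cpp
// Build a tiny graph that applies ggml_trunc and ggml_floor and compute it on CPU.
#include <cstdio>
#include "ggml.h"
#include "ggml-cpu.h"

int main() {
    struct ggml_init_params ip = { 16 * 1024 * 1024, nullptr, false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * v = (float *) a->data;
    v[0] = 3.7f; v[1] = -2.9f; v[2] = 0.5f; v[3] = -0.5f;

    struct ggml_tensor * t = ggml_trunc(ctx, a);   // towards zero
    struct ggml_tensor * f = ggml_floor(ctx, a);   // towards -inf

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, t);
    ggml_build_forward_expand(gf, f);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    for (int i = 0; i < 4; ++i) {
        std::printf("x=%5.2f trunc=%5.2f floor=%5.2f\n",
                    v[i], ((float *) t->data)[i], ((float *) f->data)[i]);
    }
    ggml_free(ctx);
    return 0;
}
```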
@ -310,6 +310,10 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||||
foreach (feat ${ARGN})
|
foreach (feat ${ARGN})
|
||||||
set(GGML_INTERNAL_${feat} ON)
|
set(GGML_INTERNAL_${feat} ON)
|
||||||
endforeach()
|
endforeach()
|
||||||
|
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||||
|
foreach (feat ${ARGN})
|
||||||
|
set(GGML_INTERNAL_${feat} ON)
|
||||||
|
endforeach()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
ggml_add_cpu_backend_variant_impl(${tag_name})
|
ggml_add_cpu_backend_variant_impl(${tag_name})
|
||||||
|
|
@ -372,6 +376,14 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
|
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
|
||||||
endif()
|
endif()
|
||||||
|
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
|
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
|
||||||
|
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
|
||||||
|
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
|
||||||
|
endif()
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
|
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
|
||||||
endif()
|
endif()
|
||||||
|
|
@ -391,6 +403,7 @@ ggml_add_backend(Vulkan)
|
||||||
ggml_add_backend(WebGPU)
|
ggml_add_backend(WebGPU)
|
||||||
ggml_add_backend(zDNN)
|
ggml_add_backend(zDNN)
|
||||||
ggml_add_backend(OpenCL)
|
ggml_add_backend(OpenCL)
|
||||||
|
ggml_add_backend(Hexagon)
|
||||||
|
|
||||||
foreach (target ggml-base ggml)
|
foreach (target ggml-base ggml)
|
||||||
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
|
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
|
||||||
|
|
|
||||||
|
|
@ -603,6 +603,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
|
||||||
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
|
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// free the extra space at the end if the new tensor is smaller
|
||||||
|
static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {
|
||||||
|
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||||
|
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
||||||
|
|
||||||
|
size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
|
||||||
|
size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
|
||||||
|
|
||||||
|
GGML_ASSERT(parent_size >= node_size);
|
||||||
|
|
||||||
|
if (parent_size > node_size) {
|
||||||
|
struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
|
||||||
|
struct buffer_address p_addr = p_hn->addr;
|
||||||
|
p_addr.offset += node_size;
|
||||||
|
size_t extra_size = parent_size - node_size;
|
||||||
|
AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
|
||||||
|
ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
|
static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
|
||||||
GGML_ASSERT(buffer_id >= 0);
|
GGML_ASSERT(buffer_id >= 0);
|
||||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||||
|
|
@ -648,6 +668,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||||
hn->addr = p_hn->addr;
|
hn->addr = p_hn->addr;
|
||||||
p_hn->allocated = false; // avoid freeing the parent
|
p_hn->allocated = false; // avoid freeing the parent
|
||||||
view_src_hn->allocated = false;
|
view_src_hn->allocated = false;
|
||||||
|
ggml_gallocr_free_extra_space(galloc, node, view_src);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -655,6 +676,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||||
hn->buffer_id = p_hn->buffer_id;
|
hn->buffer_id = p_hn->buffer_id;
|
||||||
hn->addr = p_hn->addr;
|
hn->addr = p_hn->addr;
|
||||||
p_hn->allocated = false; // avoid freeing the parent
|
p_hn->allocated = false; // avoid freeing the parent
|
||||||
|
ggml_gallocr_free_extra_space(galloc, node, parent);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,10 @@
|
||||||
#include "ggml-opencl.h"
|
#include "ggml-opencl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_HEXAGON
|
||||||
|
#include "ggml-hexagon.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef GGML_USE_BLAS
|
#ifdef GGML_USE_BLAS
|
||||||
#include "ggml-blas.h"
|
#include "ggml-blas.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -211,6 +215,9 @@ struct ggml_backend_registry {
|
||||||
#ifdef GGML_USE_OPENCL
|
#ifdef GGML_USE_OPENCL
|
||||||
register_backend(ggml_backend_opencl_reg());
|
register_backend(ggml_backend_opencl_reg());
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef GGML_USE_HEXAGON
|
||||||
|
register_backend(ggml_backend_hexagon_reg());
|
||||||
|
#endif
|
||||||
#ifdef GGML_USE_CANN
|
#ifdef GGML_USE_CANN
|
||||||
register_backend(ggml_backend_cann_reg());
|
register_backend(ggml_backend_cann_reg());
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -615,6 +622,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||||
ggml_backend_load_best("sycl", silent, dir_path);
|
ggml_backend_load_best("sycl", silent, dir_path);
|
||||||
ggml_backend_load_best("vulkan", silent, dir_path);
|
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||||
ggml_backend_load_best("opencl", silent, dir_path);
|
ggml_backend_load_best("opencl", silent, dir_path);
|
||||||
|
ggml_backend_load_best("hexagon", silent, dir_path);
|
||||||
ggml_backend_load_best("musa", silent, dir_path);
|
ggml_backend_load_best("musa", silent, dir_path);
|
||||||
ggml_backend_load_best("cpu", silent, dir_path);
|
ggml_backend_load_best("cpu", silent, dir_path);
|
||||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||||
|
|
|
||||||
|
|
@ -466,7 +466,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||||
message(STATUS "s390x detected")
|
message(STATUS "s390x detected")
|
||||||
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
|
list(APPEND GGML_CPU_SOURCES
|
||||||
|
ggml-cpu/arch/s390/quants.c)
|
||||||
|
|
||||||
|
# for native compilation
|
||||||
|
if (GGML_NATIVE)
|
||||||
|
# check machine level to determine target
|
||||||
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
||||||
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
||||||
|
|
||||||
|
|
@ -487,8 +492,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
|
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
|
||||||
list(APPEND ARCH_FLAGS -march=native -mtune=native)
|
list(APPEND ARCH_FLAGS -march=native -mtune=native)
|
||||||
endif()
|
endif()
|
||||||
|
# for cross-compilation
|
||||||
|
elseif(GGML_CPU_ALL_VARIANTS)
|
||||||
|
# range through IBM z15 to z17
|
||||||
|
# NOTE: update when a new hardware level is released
|
||||||
|
foreach (ZHW RANGE 15 17)
|
||||||
|
if(DEFINED GGML_INTERNAL_Z${ZHW})
|
||||||
|
message(STATUS "z${ZHW} cross-compile target")
|
||||||
|
list(APPEND ARCH_FLAGS -march=z${ZHW})
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
endif()
|
||||||
|
|
||||||
if (GGML_VXE)
|
if (GGML_VXE OR GGML_INTERNAL_VXE)
|
||||||
message(STATUS "VX/VXE/VXE2 enabled")
|
message(STATUS "VX/VXE/VXE2 enabled")
|
||||||
list(APPEND ARCH_FLAGS -mvx -mzvector)
|
list(APPEND ARCH_FLAGS -mvx -mzvector)
|
||||||
list(APPEND ARCH_DEFINITIONS GGML_VXE)
|
list(APPEND ARCH_DEFINITIONS GGML_VXE)
|
||||||
|
|
|
||||||
|
|
@ -2186,6 +2186,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
case GGML_UNARY_OP_EXP:
|
case GGML_UNARY_OP_EXP:
|
||||||
|
case GGML_UNARY_OP_FLOOR:
|
||||||
|
case GGML_UNARY_OP_CEIL:
|
||||||
|
case GGML_UNARY_OP_ROUND:
|
||||||
|
case GGML_UNARY_OP_TRUNC:
|
||||||
{
|
{
|
||||||
n_tasks = 1;
|
n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
|
|
@@ -3569,13 +3573,17 @@ void ggml_cpu_init(void) {
#ifdef GGML_USE_OPENMP
//if (!getenv("OMP_WAIT_POLICY")) {
// // set the wait policy to active, so that OpenMP threads don't sleep
// putenv("OMP_WAIT_POLICY=active");
// setenv("OMP_WAIT_POLICY", "active", 0)
//}

if (!getenv("KMP_BLOCKTIME")) {
// set the time to wait before sleeping a thread
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
putenv("KMP_BLOCKTIME=200"); // 200ms
#ifdef _WIN32
_putenv_s("KMP_BLOCKTIME", "200"); // 200ms
#else
setenv("KMP_BLOCKTIME", "200", 0); // 200ms
#endif
}
#endif
}
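The hunk above replaces putenv() with platform-specific calls: the MSVC CRT has no setenv(), and setenv() with overwrite = 0 leaves a user-supplied KMP_BLOCKTIME untouched. The same shim written as a reusable helper — a hedged sketch with a hypothetical name, since ggml inlines the #ifdef instead of adding a function:

```cpp
// Set an environment variable only if the user has not already set it.
#include <cstdlib>

static void set_env_default(const char * name, const char * value) {
#ifdef _WIN32
    if (std::getenv(name) == nullptr) {
        _putenv_s(name, value);            // MSVC CRT equivalent of setenv()
    }
#else
    setenv(name, value, /*overwrite=*/0);  // keep any value the user exported
#endif
}
```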
@@ -9033,6 +9033,22 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_exp(params, dst);
} break;
case GGML_UNARY_OP_FLOOR:
{
ggml_compute_forward_floor(params, dst);
} break;
case GGML_UNARY_OP_CEIL:
{
ggml_compute_forward_ceil(params, dst);
} break;
case GGML_UNARY_OP_ROUND:
{
ggml_compute_forward_round(params, dst);
} break;
case GGML_UNARY_OP_TRUNC:
{
ggml_compute_forward_trunc(params, dst);
} break;
case GGML_UNARY_OP_XIELU:
{
ggml_compute_forward_xielu(params, dst);
@@ -73,6 +73,22 @@ static inline float op_log(float x) {
return logf(x);
}

static inline float op_floor(float x) {
return floorf(x);
}

static inline float op_ceil(float x) {
return ceilf(x);
}

static inline float op_round(float x) {
return roundf(x);
}

static inline float op_trunc(float x) {
return truncf(x);
}

template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
unary_op<op_log>(params, dst);
}

void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_floor>(params, dst);
}

void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_ceil>(params, dst);
}

void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_round>(params, dst);
}

void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_trunc>(params, dst);
}

void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
const float alpha_n = ggml_get_op_params_f32(dst, 1);
const float alpha_p = ggml_get_op_params_f32(dst, 2);
@@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);

#ifdef __cplusplus
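The four ops above delegate to the C library's floorf/ceilf/roundf/truncf, so their semantics follow those functions exactly. A small illustration (not from the patch) of how they differ on a negative half-way value:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float x = -2.5f;
    std::printf("floorf(%.1f) = %.1f\n", x, floorf(x)); // -3.0 (towards -inf)
    std::printf("ceilf(%.1f)  = %.1f\n", x, ceilf(x));  // -2.0 (towards +inf)
    std::printf("roundf(%.1f) = %.1f\n", x, roundf(x)); // -3.0 (ties away from zero)
    std::printf("truncf(%.1f) = %.1f\n", x, truncf(x)); // -2.0 (towards zero)
    return 0;
}
```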
@ -1,5 +1,81 @@
|
||||||
#include "argsort.cuh"
|
#include "argsort.cuh"
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_USE_CUB
|
||||||
|
# include <cub/cub.cuh>
|
||||||
|
using namespace cub;
|
||||||
|
#endif // GGML_CUDA_USE_CUB
|
||||||
|
|
||||||
|
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
|
||||||
|
const int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int row = blockIdx.y;
|
||||||
|
|
||||||
|
if (col < ncols && row < nrows) {
|
||||||
|
indices[row * ncols + col] = col;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||||
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx <= nrows) {
|
||||||
|
offsets[idx] = idx * ncols;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_USE_CUB
|
||||||
|
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||||
|
const float * x,
|
||||||
|
int * dst,
|
||||||
|
const int ncols,
|
||||||
|
const int nrows,
|
||||||
|
ggml_sort_order order,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||||
|
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||||
|
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||||
|
|
||||||
|
int * temp_indices = temp_indices_alloc.get();
|
||||||
|
float * temp_keys = temp_keys_alloc.get();
|
||||||
|
int * d_offsets = offsets_alloc.get();
|
||||||
|
|
||||||
|
static const int block_size = 256;
|
||||||
|
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||||
|
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||||
|
|
||||||
|
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||||
|
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||||
|
|
||||||
|
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||||
|
|
||||||
|
size_t temp_storage_bytes = 0;
|
||||||
|
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||||
|
temp_indices, dst, // values (indices)
|
||||||
|
ncols * nrows, nrows, // num items, num segments
|
||||||
|
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||||
|
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
|
||||||
|
sizeof(float) * 8, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||||
|
void * d_temp_storage = temp_storage_alloc.get();
|
||||||
|
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||||
|
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||||
|
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||||
|
0, sizeof(float) * 8, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // GGML_CUDA_USE_CUB
|
||||||
|
|
||||||
|
// Bitonic sort implementation
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||||
T tmp = a;
|
T tmp = a;
|
||||||
|
|
@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
|
||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
|
static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||||
|
int * dst,
|
||||||
|
const int ncols,
|
||||||
|
const int nrows,
|
||||||
|
ggml_sort_order order,
|
||||||
|
cudaStream_t stream) {
|
||||||
// bitonic sort requires ncols to be power of 2
|
// bitonic sort requires ncols to be power of 2
|
||||||
const int ncols_pad = next_power_of_2(ncols);
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
|
||||||
|
|
@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
||||||
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
||||||
|
|
||||||
if (order == GGML_SORT_ORDER_ASC) {
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
|
||||||
|
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
|
||||||
|
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
@ -197,6 +280,19 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
if (src0->type == GGML_TYPE_I32) {
|
if (src0->type == GGML_TYPE_I32) {
|
||||||
argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
||||||
} else {
|
} else {
|
||||||
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
#ifdef GGML_CUDA_USE_CUB
|
||||||
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||||
|
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
||||||
|
|
||||||
|
if (shared_mem > max_shared_mem || ncols > 1024) {
|
||||||
|
ggml_cuda_pool & pool = ctx.pool();
|
||||||
|
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
} else {
|
||||||
|
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||||
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
|
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
|
||||||
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
|
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
|
||||||
|
|
||||||
if (block_nums.z > 65535) {
|
if (block_nums.z > 65535 || block_nums.y > 65535) {
|
||||||
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||||
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
|
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
|
||||||
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
|
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
|
||||||
|
|
|
||||||
|
|
@ -982,13 +982,6 @@ struct ggml_cuda_graph {
|
||||||
bool disable_due_to_failed_graph_capture = false;
|
bool disable_due_to_failed_graph_capture = false;
|
||||||
int number_consecutive_updates = 0;
|
int number_consecutive_updates = 0;
|
||||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||||
bool use_cpy_indirection = false;
|
|
||||||
std::vector<char *> cpy_dest_ptrs;
|
|
||||||
char ** dest_ptrs_d;
|
|
||||||
int dest_ptrs_size = 0;
|
|
||||||
// Index to allow each cpy kernel to be aware of it's position within the graph
|
|
||||||
// relative to other cpy nodes.
|
|
||||||
int graph_cpynode_index = -1;
|
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,18 +8,16 @@
|
||||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||||
|
|
||||||
template <cpy_kernel_t cpy_1>
|
template <cpy_kernel_t cpy_1>
|
||||||
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
|
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
const int nb12, const int nb13) {
|
||||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= ne) {
|
if (i >= ne) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
|
||||||
|
|
||||||
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||||
// then combine those indices with the corresponding byte offsets to get the total offsets
|
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||||
|
|
@@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+    const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+    const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
-// Copy destination pointers to GPU to be available when pointer indirection is in use
-
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
-        CUDA_CHECK(cudaStreamSynchronize(stream));
-        if (cuda_graph->dest_ptrs_d != nullptr) {
-            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
-        }
-        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
-        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
-    }
-    // copy destination pointers to GPU
-    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
-    cuda_graph->graph_cpynode_index = 0; // reset index
-#else
-    GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
-#endif
-}
-
 template<typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
@@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
@@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
     const int num_blocks = ne / QK5_0;
     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
@@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
     const int num_blocks = ne / QK5_1;
     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
@@ -259,30 +233,29 @@ static void ggml_cpy_q5_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_i32_i32(
     const char *cx, char *cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
-    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
@@ -302,23 +275,20 @@ static __global__ void cpy_i32_i32(
     const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
     const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
 
-    char * cdst_ptr = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst;
-
-    cpy_1(cx + x_offset, cdst_ptr + dst_offset);
+    cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
 static void ggml_cpy_i32_i32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
-    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
+    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream, cdst_indirect, graph_cpynode_index);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream);
 }
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -352,16 +322,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    char ** dest_ptrs_d = nullptr;
-    int graph_cpynode_index = -1;
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
-        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
-        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
-    }
-#else
-    GGML_UNUSED(disable_indirection_for_this_node);
-#endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -370,136 +330,65 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else
 #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
     {
-        if (src0->type == GGML_TYPE_F32) {
-            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-        } else {
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
-    }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        // TODO consider converting to template
+        ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
-        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
-    }
-#else
-    GGML_UNUSED(disable_indirection_for_this_node);
-#endif
-
 }
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    bool disable_indirection = true;
-    ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
+    ggml_cuda_cpy(ctx, src0, dst);
-}
-
-void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        // Prioritize CUDA graph compatibility over direct memory copy optimization.
-        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
-        if (src0->type == GGML_TYPE_F32) {
-            return (void*) cpy_flt<cpy_1_flt<float, float>>;
-        } else {
-            return nullptr;
-        }
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<float, float>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<float, half>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
-    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
-    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
-    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
-    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
-    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<half, half>>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<half, float>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
-    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
-    } else {
-        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
-                ggml_type_name(src0->type), ggml_type_name(src1->type));
-    }
 }
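Note on the copy-kernel hunks above: every kernel maps a flat element index i to 4-D coordinates from the extents ne00/ne01/ne02 and then to byte offsets via the strides nb00..nb03 (and nb10..nb13 for the destination); with indirection removed, the kernels simply write through the cdst argument. The following standalone host-side sketch reproduces that decomposition; the function name and the example tensor shape are illustrative only and are not part of the patch.

    #include <cstdint>
    #include <cstdio>

    // Map flat index i to (i0,i1,i2,i3) using extents ne0..ne2, then to a byte
    // offset using strides nb0..nb3 -- the same arithmetic as cpy_flt/cpy_f32_q.
    static int64_t offset_of(int64_t i, int64_t ne0, int64_t ne1, int64_t ne2,
                             int64_t nb0, int64_t nb1, int64_t nb2, int64_t nb3) {
        const int64_t i3 = i/(ne0*ne1*ne2);
        const int64_t i2 = (i - i3*ne0*ne1*ne2)/(ne0*ne1);
        const int64_t i1 = (i - i3*ne0*ne1*ne2 - i2*ne0*ne1)/ne0;
        const int64_t i0 =  i - i3*ne0*ne1*ne2 - i2*ne0*ne1 - i1*ne0;
        return i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
    }

    int main() {
        // contiguous 4x3x2x2 float tensor: nb0 = 4, nb1 = 16, nb2 = 48, nb3 = 96
        printf("%lld\n", (long long) offset_of(13, 4, 3, 2, 4, 16, 48, 96)); // 52 == 13*sizeof(float)
        return 0;
    }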
@@ -2,10 +2,6 @@
 
 #define CUDA_CPY_BLOCK_SIZE 64
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-
-void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
-
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
@@ -895,6 +895,7 @@ void launch_fattn(
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+    GGML_ASSERT(max_blocks_per_sm > 0);
     int parallel_blocks = max_blocks_per_sm;
 
     dim3 blocks_num;
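For context on the launch_fattn hunk: cudaOccupancyMaxActiveBlocksPerMultiprocessor reports how many blocks of the given kernel can be resident per SM for the chosen block size and dynamic shared memory, and the new GGML_ASSERT catches a zero result (for example when the requested shared memory exceeds the device limit), which would otherwise feed a zero block count into the launch grid. A minimal standalone illustration with a placeholder kernel, not taken from the diff:

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void dummy_kernel(float * dst) {
        dst[blockIdx.x*blockDim.x + threadIdx.x] = 1.0f;
    }

    int main() {
        int max_blocks_per_sm = 0;
        const int    block_size    = 256;
        const size_t nbytes_shared = 0; // dynamic shared memory per block
        cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_blocks_per_sm, dummy_kernel, block_size, nbytes_shared);
        if (err != cudaSuccess || max_blocks_per_sm <= 0) {
            fprintf(stderr, "occupancy query failed or reported 0 resident blocks\n");
            return 1;
        }
        printf("max active blocks per SM: %d\n", max_blocks_per_sm);
        return 0;
    }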
@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
     const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
     const int nwarps = nthreads / WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
-    constexpr bool need_f16_K = false;
-    constexpr bool need_f16_V = false;
+    const bool need_f16_K = type_K == GGML_TYPE_F16;
+    const bool need_f16_V = type_V == GGML_TYPE_F16;
     constexpr size_t nbytes_shared = 0;
     launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
@@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    GGML_ASSERT(K->type == type_K);
-    GGML_ASSERT(V->type == type_V);
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -117,10 +117,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
 }
 
 #define FATTN_VEC_CASE(D, type_K, type_V) \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
+    { \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
             ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
             return; \
         } \
+    } \
 
 #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
     FATTN_VEC_CASE( 64, type_K, type_V) \
@@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     switch (K->type) {
+        case GGML_TYPE_F32:
         case GGML_TYPE_F16:
             break;
         case GGML_TYPE_Q4_1:
@@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // If Turing tensor cores available, use them:
     if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
         if (can_use_vector_kernel) {
-            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     return BEST_FATTN_KERNEL_VEC;
                 }
@@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
             if (Q->ne[1] == 1) {
                 if (!gqa_opt_applies) {
                     return BEST_FATTN_KERNEL_VEC;
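Taken together, the fattn changes above let F32 K/V tensors reach the vector kernels: FATTN_VEC_CASE now accepts an F32 tensor for a case compiled for F16 (need_f16_K/V is derived from the compiled type, so launch_fattn can materialize an F16 copy when needed), and the K->type switch gains a GGML_TYPE_F32 label. A rough host-side rendering of the macro's matching rule, with hypothetical names used only for illustration:

    #include <cstdio>

    enum sketch_type { TYPE_F16, TYPE_F32, TYPE_Q8_0 };

    // Mirrors type_K_okay / type_V_okay: exact match, or F32 data for an F16 case.
    static bool type_okay(sketch_type tensor_type, sketch_type case_type) {
        return tensor_type == case_type || (tensor_type == TYPE_F32 && case_type == TYPE_F16);
    }

    static bool vec_case_matches(long long head_dim, long long D,
                                 sketch_type K, sketch_type V,
                                 sketch_type case_K, sketch_type case_V) {
        return head_dim == D && type_okay(K, case_K) && type_okay(V, case_V);
    }

    int main() {
        // F32 K/V now matches a case compiled for F16 K/V at head size 128:
        printf("%d\n", vec_case_matches(128, 128, TYPE_F32, TYPE_F32, TYPE_F16, TYPE_F16));
        return 0;
    }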
@@ -2774,11 +2774,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
     int batch_size, bool use_cuda_graph) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
     const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
     const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@@ -2839,33 +2838,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
         }
 
-        if (node->op == GGML_OP_CPY) {
-
-            // Store the pointers which are updated for each token, such that these can be sent
-            // to the device and accessed using indirection from CUDA graph
-            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
-
-            // store a pointer to each copy op CUDA kernel to identify it later
-            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-            if (!ptr) {
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-            }
-        }
-
         if (!use_cuda_graph) {
             break;
         }
     }
 
-    if (use_cuda_graph) {
-        cuda_ctx->cuda_graph->use_cpy_indirection = true;
-        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
-        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
-    }
-
     return use_cuda_graph;
 }
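For context on the two hunks above: the deleted branch implemented copy-op indirection for CUDA graphs. Destination pointers of every GGML_OP_CPY node were collected per token, uploaded to a device-side char ** table, and the copy kernels dereferenced that table, so a captured graph stayed valid while destination addresses changed between tokens. After this update the compatibility check no longer rewrites copy ops. A minimal standalone sketch of the indirection idea itself, with a hypothetical kernel and names rather than the removed code:

    #include <cstdio>
    #include <cuda_runtime.h>

    // The kernel reads its destination through a device-resident pointer table,
    // so the host can retarget the copy without changing the recorded arguments.
    __global__ void copy_through_table(const float * src, float * const * dst_table, int slot, int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            dst_table[slot][i] = src[i];
        }
    }

    int main() {
        const int n = 256;
        float *  src       = nullptr;
        float *  dst       = nullptr;
        float ** dst_table = nullptr;
        cudaMalloc((void **) &src, n*sizeof(float));
        cudaMalloc((void **) &dst, n*sizeof(float));
        cudaMalloc((void **) &dst_table, sizeof(float *));
        // refresh the table from the host; the kernel launch itself is untouched
        cudaMemcpy(dst_table, &dst, sizeof(float *), cudaMemcpyHostToDevice);
        copy_through_table<<<(n + 255)/256, 256>>>(src, dst_table, 0, n);
        cudaDeviceSynchronize();
        printf("copy submitted\n");
        cudaFree(src); cudaFree(dst); cudaFree(dst_table);
        return 0;
    }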
@@ -2884,7 +2861,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
 
 static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     if (node->data != graph_node_properties->node_address &&
-          node->op != GGML_OP_CPY &&
           node->op != GGML_OP_VIEW) {
         return false;
     }
@@ -2905,7 +2881,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         if (node->src[i] &&
             node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_CPY &&
             node->op != GGML_OP_VIEW
         ) {
             return false;
@@ -2985,18 +2960,15 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #endif
 
     //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+    std::initializer_list<enum ggml_op> topk_moe_ops =
+        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
 
-    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
-        if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
-            return false;
-        }
-
-        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
-        }
-
+    if (ops.size() == topk_moe_ops_with_norm.size() &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx+8];
 
@@ -3005,16 +2977,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
-    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
-        if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
-            return false;
-        }
-
-        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
-        }
-
+    if (ops.size() == topk_moe_ops.size() &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx+4];
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@@ -3022,6 +2986,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
+        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3052,7 +3026,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     }
 
     //if rms norm is the B operand, then we don't handle broadcast
-    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
        return false;
     }
 
@@ -3121,7 +3095,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
                     ggml_tensor * weights = cgraph->nodes[i+8];
                     ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                    ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                    ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+                                          /*delayed softmax*/ false);
                     i += 8;
                     continue;
                 }
@ -3129,11 +3104,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
|
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
|
||||||
ggml_tensor * weights = cgraph->nodes[i+4];
|
ggml_tensor * weights = cgraph->nodes[i+4];
|
||||||
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
||||||
ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
|
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
|
||||||
|
/*delayed softmax*/ false);
|
||||||
i += 4;
|
i += 4;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ggml_cuda_can_fuse(cgraph, i,
|
||||||
|
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
|
||||||
|
ggml_tensor * weights = cgraph->nodes[i + 5];
|
||||||
|
ggml_tensor * ids = cgraph->nodes[i + 1];
|
||||||
|
|
||||||
|
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
|
||||||
|
/*delayed_softmax*/ true);
|
||||||
|
i += 5;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (node->op == GGML_OP_ADD) {
|
if (node->op == GGML_OP_ADD) {
|
||||||
int n_fuse = 0;
|
int n_fuse = 0;
|
||||||
ggml_op ops[8];
|
ggml_op ops[8];
|
||||||
|
|
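The three fused top-k MoE dispatches above differ only in which graph nodes they read their outputs from and by how many nodes they advance the loop index. A small C++ sketch of that layout, written purely as a reading aid (the struct and array below are not part of the patch):

// Reading aid only: node offsets used by the three fusion variants above.
struct topk_moe_fusion_layout {
    int weights_offset; // node i + N holding the routing weights
    int ids_offset;     // node i + N holding the selected expert ids
    int advance;        // nodes consumed in addition to node i
};
static constexpr topk_moe_fusion_layout topk_moe_layouts[] = {
    {8, 3, 8}, // with norm
    {4, 3, 4}, // without norm
    {5, 1, 5}, // without norm, delayed softmax (the softmax node sits at i + 4)
};
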
@@ -3278,7 +3265,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);

-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, batch_size, use_cuda_graph);
+        use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph);

        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
        if (use_cuda_graph && cuda_graph_update_required) {

@@ -3305,10 +3292,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
            CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
        }
    }
-
-    if (!use_cuda_graph) {
-        cuda_ctx->cuda_graph->use_cpy_indirection = false;
-    }

 #else
    bool use_cuda_graph = false;
    bool cuda_graph_update_required = false;

@@ -3922,12 +3905,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_2D:
        case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
        case GGML_OP_ACC:
            return true;
+        case GGML_OP_SUM:
+            return ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_ARGSORT:
-            // TODO: Support arbitrary column width
+#ifndef GGML_CUDA_USE_CUB
            return op->src[0]->ne[0] <= 1024;
+#else
+            return true;
+#endif
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
        case GGML_OP_GROUP_NORM:

@@ -1,5 +1,7 @@
 #include "ggml.h"
 #include "mmf.cuh"
+#include "mmid.cuh"

 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT( src1->type == GGML_TYPE_F32);

@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
     const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

+    mmf_ids_data ids_info{};
+    mmf_ids_data * ids_info_ptr = nullptr;
+    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
+
     // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
     const int64_t ncols_dst = ids ? ne2 : ne1;
     const int64_t nchannels_dst = ids ? ne1 : ne2;

@@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         nchannels_y = ids->ne[0];
     }

+    if (ids && ncols_dst > 16) {
+        const int64_t n_expert_used = ids->ne[0];
+        const int64_t n_experts = ne02;
+        const int64_t n_tokens = ne12;
+        const int64_t ne_get_rows = n_tokens * n_expert_used;
+
+        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
+
+        const int si1 = static_cast<int>(ids_s1);
+        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
+
+        GGML_ASSERT(sis1 > 0);
+
+        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
+            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
+        CUDA_CHECK(cudaGetLastError());
+
+        ids_info.ids_src_compact = ids_src_compact_dev.get();
+        ids_info.ids_dst_compact = ids_dst_compact_dev.get();
+        ids_info.expert_bounds_dev = expert_bounds_dev.get();
+        ids_info.n_experts = static_cast<int>(n_experts);
+        ids_info.sis1 = sis1;
+        ids_info_ptr = &ids_info;
+    }
+
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;

@@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;

@@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;

@@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));

@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     }

     if (mul_mat_id) {
-        if (type == GGML_TYPE_F32 && src1_ncols > 32) {
+        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
             return false;
-        }
-        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
+        } else if(src0_ne[1] > 1024 && src1_ncols > 128) {
             return false;
         }
     } else {

@@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;

 #define MMF_ROWS_PER_BLOCK 32

+struct mmf_ids_data {
+    const int32_t * ids_src_compact = nullptr;
+    const int32_t * ids_dst_compact = nullptr;
+    const int32_t * expert_bounds_dev = nullptr;
+    int n_experts = 0;
+    int sis1 = 0;
+};
+
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);

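mmf_ids_data is only a carrier between the host-side setup in mmf.cu and the kernel launch in mul_mat_f_switch_ids. A condensed sketch of how the earlier hunk fills it, with the field meanings spelled out; this is an illustration assembled from the hunks in this patch, not a drop-in replacement:

// Sketch only, condensed from the mmf.cu hunk above.
mmf_ids_data ids_info{};
ids_info.ids_src_compact   = ids_src_compact_dev.get();  // entries of the form it*sis1 + iex % nchannels_y, grouped by expert
ids_info.ids_dst_compact   = ids_dst_compact_dev.get();  // entries of the form it*n_expert_used + iex, same grouping
ids_info.expert_bounds_dev = expert_bounds_dev.get();    // expert i owns compact entries [bounds[i], bounds[i+1])
ids_info.n_experts         = static_cast<int>(n_experts);
ids_info.sis1              = sis1;                       // src1->nb[2] / src1->nb[1], used to split token and channel
// ids_info is then passed as the trailing ids_data argument through
// mul_mat_f_switch_cols_per_block -> mul_mat_f_cuda -> mul_mat_f_switch_ids.
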
@@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }

+
+//This kernel is for larger batch sizes of mul_mat_id
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f_ids(
+        const T * __restrict__ x, const float * __restrict__ y,
+        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
+        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
+        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const uint3 sis1_fd, const uint3 nch_fd) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T> tile_A;
+    typedef tile< 8, 8, T> tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0 = blockIdx.x * rows_per_block;
+
+    const int expert_idx = blockIdx.y;
+    const int expert_start = expert_bounds[expert_idx];
+    const int expert_end = expert_bounds[expert_idx + 1];
+    const int ncols_expert = expert_end - expert_start;
+
+    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
+    const int tile_idx = blockIdx.z;
+    if (tile_idx >= tiles_for_expert) {
+        return;
+    }
+
+    const int col_base = tile_idx * cols_per_block;
+
+    GGML_UNUSED(channel_ratio);
+
+    const int channel_x = expert_idx;
+    const int sample_dst = 0;
+    const int sample_x = sample_dst / sample_ratio;
+    const int sample_y = sample_dst;
+
+    x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
+    y += int64_t(sample_y) *stride_sample_y;
+    dst += int64_t(sample_dst)*stride_sample_dst;
+
+    const int32_t * ids_src_expert = ids_src_compact + expert_start;
+    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
+
+    extern __shared__ char data_mmv[];
+    char * compute_base = data_mmv;
+
+    //const float2 * y2 = (const float2 *) y;
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+        if constexpr (std::is_same_v<T, float>) {
+            float vals_buf[2][tile_B::I];
+            auto gather_tile = [&](int tile_idx_local, float *vals) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + tile_idx_local*tile_B::I;
+                    const int global_j = col_base + j;
+                    float val = 0.0f;
+                    if (j < cols_per_block && global_j < ncols_expert) {
+                        const int src_entry = ids_src_expert[global_j];
+                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+                        const int token = (int) qrm.x;
+                        const int channel = (int) qrm.y;
+                        if (token < ncols_dst_total) {
+                            val = y[channel*stride_channel_y + token*stride_col_y + col];
+                        }
+                    }
+                    vals[j0] = val;
+                }
+            };
+
+            gather_tile(0, vals_buf[0]);
+
+            int curr_buf = 0;
+            int next_buf = 1;
+#pragma unroll
+            for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
+                }
+
+                if (itB + 1 < ntB) {
+                    gather_tile(itB + 1, vals_buf[next_buf]);
+                }
+
+#pragma unroll
+                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                    tile_B B;
+                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                    for (int itA = 0; itA < ntA; ++itA) {
+                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                    }
+                }
+
+                if (itB + 1 < ntB) {
+                    curr_buf ^= 1;
+                    next_buf ^= 1;
+                }
+            }
+        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+            float2 vals_buf[2][tile_B::I];
+            auto gather_tile = [&](int tile_idx_local, float2 *vals) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + tile_idx_local*tile_B::I;
+                    const int global_j = col_base + j;
+                    float2 tmp = make_float2(0.0f, 0.0f);
+                    if (j < cols_per_block && global_j < ncols_expert) {
+                        const int src_entry = ids_src_expert[global_j];
+                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+                        const int token = (int) qrm.x;
+                        const int channel = (int) qrm.y;
+                        if (token < ncols_dst_total) {
+                            tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
+                        }
+                    }
+                    vals[j0] = tmp;
+                }
+            };
+
+            if (ntB > 0) {
+                gather_tile(0, vals_buf[0]);
+            }
+
+            int curr_buf = 0;
+            int next_buf = 1;
+#pragma unroll
+            for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const float2 tmp = vals_buf[curr_buf][j0];
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+
+                if (itB + 1 < ntB) {
+                    gather_tile(itB + 1, vals_buf[next_buf]);
+                }
+
+#pragma unroll
+                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                    tile_B B;
+                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                    for (int itA = 0; itA < ntA; ++itA) {
+                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                    }
+                }
+
+                if (itB + 1 < ntB) {
+                    curr_buf ^= 1;
+                    next_buf ^= 1;
+                }
+            }
+        } else {
+            static_assert(std::is_same_v<T, void>, "unsupported type");
+        }
+    }
+
+    float * buf_iw = (float *) compute_base;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+
+        const int global_j = col_base + j;
+        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
+            const int dst_entry = ids_dst_expert[global_j];
+            const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
+            const int token = (int) qrm.x;
+            if (token < ncols_dst_total) {
+                const int slot = (int) qrm.y;
+                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
+            }
+        }
+    }
+#else
+    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
+        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
+    NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
 template<typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,

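In the kernel added above, the CUDA grid maps directly onto the compacted problem: blockIdx.x picks a rows_per_block-row tile of the current expert's weight matrix, blockIdx.y picks the expert (via expert_bounds), and blockIdx.z picks a cols_per_block-wide tile inside that expert's slice of the compact id list, with blocks past the end of the slice exiting early. A minimal sketch of the matching launch shape, assuming the names used by the host code later in this patch:

// Illustration only: grid geometry for mul_mat_f_ids.
const int max_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
const dim3 block_nums_ids(
    nrows_x / MMF_ROWS_PER_BLOCK, // blockIdx.x: row tile of the expert's weight matrix
    n_experts,                    // blockIdx.y: expert index, sliced via expert_bounds[y .. y+1]
    max_tiles);                   // blockIdx.z: column tile inside the expert's compact slice
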
@@ -232,11 +484,33 @@ static inline void mul_mat_f_switch_ids(
         const int64_t stride_col_id, const int64_t stride_row_id,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
-    if (ids) {
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
+        const mmf_ids_data * ids_data) {
+    const bool has_ids_data = ids_data && ids_data->ids_src_compact;
+
+    // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
+    // we prefer the normal mul_mat_f path with has_ids=true.
+    if (has_ids_data && ncols_dst > 16) {
+        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
+        if (max_tiles == 0) {
+            return;
+        }
+        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
+
+        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
+        const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
+
+        mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
+             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
+             sis1_fd, nch_fd);
+    } else if (ids) {
         const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
         dim3 block_nums_ids = block_nums;
         block_nums_ids.y *= col_tiles;

         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,

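sis1_fd and nch_fd above are precomputed fastdiv constants; the kernel uses them to split each compact id entry back into a token index and a channel (or destination slot). In plain integer arithmetic the round trip looks like this; the values are made up purely for illustration:

// Encode (done by mm_ids_helper): entry = token*sis1 + slot.
// Decode (done by mul_mat_f_ids via fast_div_modulo):
const int sis1  = 4;            // hypothetical src1 row stride
const int entry = 7*sis1 + 2;   // token 7, slot 2
const int token = entry / sis1; // 7, matches qrm.x
const int slot  = entry % sis1; // 2, matches qrm.y
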
@@ -258,7 +532,7 @@ void mul_mat_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        cudaStream_t stream, const mmf_ids_data * ids_data) {
     typedef tile<16, 8, T> tile_A;
     typedef tile< 8, 8, T> tile_B;

@@ -290,7 +564,7 @@ void mul_mat_f_cuda(
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
-    const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
+    const int64_t grid_y = ids ? nchannels_x : nchannels_dst;

     const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
     const dim3 block_dims(warp_size, nwarps_best, 1);

@@ -300,49 +574,57 @@ void mul_mat_f_cuda(
             mul_mat_f_switch_ids<T, cols_per_block, 1>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 2: {
             mul_mat_f_switch_ids<T, cols_per_block, 2>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 3: {
             mul_mat_f_switch_ids<T, cols_per_block, 3>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 4: {
             mul_mat_f_switch_ids<T, cols_per_block, 4>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 5: {
             mul_mat_f_switch_ids<T, cols_per_block, 5>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 6: {
             mul_mat_f_switch_ids<T, cols_per_block, 6>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 7: {
             mul_mat_f_switch_ids<T, cols_per_block, 7>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 8: {
             mul_mat_f_switch_ids<T, cols_per_block, 8>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         default: {
             GGML_ABORT("fatal error");

@@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        cudaStream_t stream, const mmf_ids_data * ids_data) {

     const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;

@@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
         case 1: {
             mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 2: {
             mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 3: {
             mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 4: {
             mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 5: {
             mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 6: {
             mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 7: {
             mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 8: {
             mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 9: {
             mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 10: {
             mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 11: {
             mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 12: {
             mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 13: {
             mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 14: {
             mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 15: {
             mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 16: {
             mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         default: {
             GGML_ABORT("fatal error");

@@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
     const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
     const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
     const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
-    cudaStream_t stream);
+    cudaStream_t stream, const mmf_ids_data * ids_data);

 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 #define DECL_MMF_CASE_EXTERN(ncols_dst) \

@@ -0,0 +1,164 @@
+#include "common.cuh"
+#include "mmid.cuh"
+
+// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
+struct mm_ids_helper_store {
+    uint32_t data;
+
+    __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+        data = (it & 0x003FFFFF) | (iex_used << 22);
+    }
+
+    __device__ uint32_t it() const {
+        return data & 0x003FFFFF;
+    }
+
+    __device__ uint32_t iex_used() const {
+        return data >> 22;
+    }
+};
+static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
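A small worked example of the layout described in the comment above may help (values chosen purely for illustration): with 2 experts, 3 tokens, n_expert_used = 2, routing ids t0 -> {0, 1}, t1 -> {0, 1}, t2 -> {1, 0}, and taking sis1 = 1 and nchannels_y = 1, grouping the (token, slot) pairs by expert and applying the formulas used in the kernel below gives:

// Worked example only; not part of the patch.
// expert 0 slice: (t0, slot 0), (t1, slot 0), (t2, slot 1)
// expert 1 slice: (t0, slot 1), (t1, slot 1), (t2, slot 0)
static const int32_t example_expert_bounds[] = {0, 3, 6};          // expert i owns entries [b[i], b[i+1])
static const int32_t example_ids_dst[]       = {0, 2, 5, 1, 3, 4}; // it*n_expert_used + iex_used
static const int32_t example_ids_src1[]      = {0, 1, 2, 0, 1, 2}; // it*sis1 + iex_used % nchannels_y
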
|
template <int n_expert_used_template>
|
||||||
|
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||||
|
static __global__ void mm_ids_helper(
|
||||||
|
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||||
|
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
|
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||||
|
const int expert = blockIdx.x;
|
||||||
|
|
||||||
|
extern __shared__ char data_mm_ids_helper[];
|
||||||
|
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
|
||||||
|
|
||||||
|
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||||
|
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||||
|
|
||||||
|
if constexpr (n_expert_used_template == 0) {
|
||||||
|
// Generic implementation:
|
||||||
|
for (int it = 0; it < n_tokens; ++it) {
|
||||||
|
int iex_used = -1; // The index at which the expert is used, if any.
|
||||||
|
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||||
|
const int expert_used = ids[it*si1 + iex];
|
||||||
|
nex_prev += expert_used < expert;
|
||||||
|
if (expert_used == expert) {
|
||||||
|
iex_used = iex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iex_used != -1) {
|
||||||
|
store[it_compact] = mm_ids_helper_store(it, iex_used);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||||
|
it_compact++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Implementation optimized for specific numbers of experts used:
|
||||||
|
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||||
|
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||||
|
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||||
|
const int it = it0 + threadIdx.x / neu_padded;
|
||||||
|
|
||||||
|
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||||
|
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||||
|
ids[it*si1 + iex] : INT_MAX;
|
||||||
|
const int iex_used = expert_used == expert ? iex : -1;
|
||||||
|
nex_prev += expert_used < expert;
|
||||||
|
|
||||||
|
// Whether the threads at this token position have used the expert:
|
||||||
|
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||||
|
|
||||||
|
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||||
|
int it_compact_add_lower = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||||
|
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||||
|
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
|
||||||
|
it_compact_add_lower += tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iex_used != -1) {
|
||||||
|
store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||||
|
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||||
|
|
||||||
|
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||||
|
const mm_ids_helper_store store_it = store[itc];
|
||||||
|
const int it = store_it.it();
|
||||||
|
const int iex_used = store_it.iex_used();
|
||||||
|
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||||
|
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x != 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
expert_bounds[expert] = nex_prev;
|
||||||
|
|
||||||
|
if (expert < static_cast<int>(gridDim.x) - 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int n_expert_used_template>
|
||||||
|
static void launch_mm_ids_helper(
|
||||||
|
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||||
|
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
|
||||||
|
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
|
||||||
|
|
||||||
|
const int id = ggml_cuda_get_device();
|
||||||
|
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||||
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||||
|
CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
|
||||||
|
|
||||||
|
const dim3 num_blocks(n_experts, 1, 1);
|
||||||
|
const dim3 block_size(warp_size, 1, 1);
|
||||||
|
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
|
||||||
|
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||||
|
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||||
|
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_launch_mm_ids_helper(
|
||||||
|
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||||
|
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||||
|
switch (n_expert_used) {
|
||||||
|
case 2:
|
||||||
|
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||||
|
break;
|
||||||
|
case 16:
|
||||||
|
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||||
|
break;
|
||||||
|
case 32:
|
||||||
|
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
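A minimal CPU sketch of the mapping the helper above computes, assuming ids is a row-major [n_tokens][n_expert_used] matrix (si1 == n_expert_used) and that each expert appears at most once per token; all names here are illustrative and not part of the CUDA sources.

#include <cstdint>
#include <vector>

// For each expert, collect the (token, slot) pairs routed to it, in token order,
// and record per-expert offsets in expert_bounds -- the same layout that
// ids_src1 / ids_dst / expert_bounds get on the GPU.
static void build_mm_id_mapping(const std::vector<int32_t> & ids, int n_experts, int n_tokens,
                                int n_expert_used, int nchannels_y, int sis1,
                                std::vector<int32_t> & ids_src1, std::vector<int32_t> & ids_dst,
                                std::vector<int32_t> & expert_bounds) {
    ids_src1.clear();
    ids_dst.clear();
    expert_bounds.assign(n_experts + 1, 0);
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds[expert] = (int32_t) ids_dst.size(); // columns of all lower-index experts
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*n_expert_used + iex] == expert) {
                    ids_src1.push_back(it*sis1 + iex % nchannels_y);
                    ids_dst .push_back(it*n_expert_used + iex);
                }
            }
        }
    }
    expert_bounds[n_experts] = (int32_t) ids_dst.size();
}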
@ -0,0 +1,5 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
void ggml_cuda_launch_mm_ids_helper(
|
||||||
|
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
|
||||||
|
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
|
||||||
|
|
@ -1,141 +1,6 @@
|
||||||
#include "mmq.cuh"
|
#include "mmq.cuh"
|
||||||
#include "quantize.cuh"
|
#include "quantize.cuh"
|
||||||
|
#include "mmid.cuh"
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
|
||||||
struct mmq_ids_helper_store {
|
|
||||||
uint32_t data;
|
|
||||||
|
|
||||||
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
|
||||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ uint32_t it() const {
|
|
||||||
return data & 0x003FFFFF;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ uint32_t iex_used() const {
|
|
||||||
return data >> 22;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
|
|
||||||
|
|
||||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
|
||||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
|
||||||
// ids_dst describes the same mapping but for the dst tensor.
|
|
||||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
|
||||||
template <int n_expert_used_template>
|
|
||||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
|
||||||
static __global__ void mmq_ids_helper(
|
|
||||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
|
||||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
|
||||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
||||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
|
||||||
const int expert = blockIdx.x;
|
|
||||||
|
|
||||||
extern __shared__ char data_mmq_ids_helper[];
|
|
||||||
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
|
|
||||||
|
|
||||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
|
||||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
|
||||||
|
|
||||||
if constexpr (n_expert_used_template == 0) {
|
|
||||||
// Generic implementation:
|
|
||||||
for (int it = 0; it < n_tokens; ++it) {
|
|
||||||
int iex_used = -1; // The index at which the expert is used, if any.
|
|
||||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
|
||||||
const int expert_used = ids[it*si1 + iex];
|
|
||||||
nex_prev += expert_used < expert;
|
|
||||||
if (expert_used == expert) {
|
|
||||||
iex_used = iex;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (iex_used != -1) {
|
|
||||||
store[it_compact] = mmq_ids_helper_store(it, iex_used);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
|
||||||
it_compact++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Implementation optimized for specific numbers of experts used:
|
|
||||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
|
||||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
|
||||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
|
||||||
const int it = it0 + threadIdx.x / neu_padded;
|
|
||||||
|
|
||||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
|
||||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
|
||||||
ids[it*si1 + iex] : INT_MAX;
|
|
||||||
const int iex_used = expert_used == expert ? iex : -1;
|
|
||||||
nex_prev += expert_used < expert;
|
|
||||||
|
|
||||||
// Whether the threads at this token position have used the expert:
|
|
||||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
|
||||||
|
|
||||||
// Do a scan over threads at lower token positions in the warp to get the correct index for writing data:
|
|
||||||
int it_compact_add_lower = 0;
|
|
||||||
#pragma unroll
|
|
||||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
|
||||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
|
||||||
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
|
|
||||||
it_compact_add_lower += tmp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (iex_used != -1) {
|
|
||||||
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The thread with the highest index in the warp always has the sum over the whole warp; use it to increment all threads:
|
|
||||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
|
||||||
|
|
||||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
|
||||||
const mmq_ids_helper_store store_it = store[itc];
|
|
||||||
const int it = store_it.it();
|
|
||||||
const int iex_used = store_it.iex_used();
|
|
||||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
|
||||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x != 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
expert_bounds[expert] = nex_prev;
|
|
||||||
|
|
||||||
if (expert < static_cast<int>(gridDim.x) - 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int n_expert_used_template>
|
|
||||||
static void launch_mmq_ids_helper(
|
|
||||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
|
||||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
|
|
||||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
|
|
||||||
|
|
||||||
const int id = ggml_cuda_get_device();
|
|
||||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
|
||||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
|
||||||
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
|
|
||||||
|
|
||||||
const dim3 num_blocks(n_experts, 1, 1);
|
|
||||||
const dim3 block_size(warp_size, 1, 1);
|
|
||||||
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
|
|
||||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
|
||||||
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
|
||||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||||
switch (args.type_x) {
|
switch (args.type_x) {
|
||||||
|
|
@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
|
||||||
const int si1 = ids->nb[1] / ggml_element_size(ids);
|
const int si1 = ids->nb[1] / ggml_element_size(ids);
|
||||||
const int sis1 = nb12 / nb11;
|
const int sis1 = nb12 / nb11;
|
||||||
|
|
||||||
switch (n_expert_used) {
|
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||||
case 2:
|
|
||||||
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
|
||||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
|
||||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
|
||||||
break;
|
|
||||||
case 6:
|
|
||||||
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
|
||||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
|
||||||
break;
|
|
||||||
case 8:
|
|
||||||
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
|
||||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
|
||||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
|
||||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
|
||||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
|
||||||
static __global__ void mul_mat_vec_f(
|
static __global__ void mul_mat_vec_f(
|
||||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||||
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
||||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||||
const int row = blockIdx.x;
|
const int row = blockIdx.x;
|
||||||
const int channel_dst = blockIdx.y;
|
const int channel_dst = blockIdx.y;
|
||||||
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
|
||||||
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
||||||
const int sample_dst = blockIdx.z;
|
const int sample_dst = blockIdx.z;
|
||||||
const int sample_x = sample_dst / sample_ratio;
|
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
|
||||||
const int sample_y = sample_dst;
|
const int sample_y = sample_dst;
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
|
|
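The hunks below replace plain integer division of block indices by runtime-constant ratios with fastdiv over values packed by init_fastdiv_values. As a standalone C++ sketch of the usual precomputed multiplier-and-shift trick this pattern relies on -- the exact formula here is an assumption, not copied from the CUDA headers:

#include <cassert>
#include <cstdint>

struct fastdiv_vals { uint32_t mp, L, d; }; // multiplier, shift, divisor

static fastdiv_vals init_fastdiv_sketch(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) L++;                    // L = ceil(log2(d))
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32)*((uint64_t{1} << L) - d))/d + 1);
    return {mp, L, d};
}

static uint32_t fastdiv_sketch(uint32_t n, fastdiv_vals v) {
    const uint32_t hi = (uint32_t) (((uint64_t) n*v.mp) >> 32);      // __umulhi on the GPU
    return (uint32_t) (((uint64_t) hi + n) >> v.L);                  // 64-bit sum avoids CPU-side overflow
}

int main() {
    for (uint32_t d : {1u, 3u, 7u, 12u, 255u}) {
        const fastdiv_vals v = init_fastdiv_sketch(d);
        for (uint32_t n = 0; n < 1000000; ++n) {
            assert(fastdiv_sketch(n, v) == n/d);
        }
    }
    return 0;
}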
@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
sumf[j] += tmpx.x*tmpy.x;
|
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||||
sumf[j] += tmpx.y*tmpy.y;
|
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same_v<T, half>) {
|
} else if constexpr (std::is_same_v<T, half>) {
|
||||||
|
|
@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
sumf[j] += tmpx.x * tmpy.x;
|
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||||
sumf[j] += tmpx.y * tmpy.y;
|
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
|
||||||
#endif // FP16_AVAILABLE
|
#endif // FP16_AVAILABLE
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
||||||
|
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
||||||
|
#if defined(GGML_USE_HIP)
|
||||||
const int * x2 = (const int *) x;
|
const int * x2 = (const int *) x;
|
||||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||||
const int tmpx = x2[col2];
|
const int tmpx = x2[col2];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
|
||||||
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
||||||
|
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
||||||
|
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
||||||
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||||
|
const nv_bfloat162 tmpx = x2[col2];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
|
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||||
|
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||||
}
|
}
|
||||||
|
|
@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
|
||||||
GGML_ASSERT(stride_col_y % 2 == 0);
|
GGML_ASSERT(stride_col_y % 2 == 0);
|
||||||
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||||||
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
|
||||||
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
|
||||||
|
|
||||||
const int device = ggml_cuda_get_device();
|
const int device = ggml_cuda_get_device();
|
||||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||||
|
|
@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
|
||||||
case 32: {
|
case 32: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
case 64: {
|
case 64: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
case 96: {
|
case 96: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
case 128: {
|
case 128: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
case 160: {
|
case 160: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
case 192: {
|
case 192: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
case 224: {
|
case 224: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
case 256: {
|
case 256: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} break;
|
} break;
|
||||||
default: {
|
default: {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,61 @@
|
||||||
|
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
|
|
||||||
|
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||||
|
template <int experts_per_thread, bool use_limit>
|
||||||
|
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
|
||||||
|
float max_val = -INFINITY;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
const int idx = lane + i * WARP_SIZE;
|
||||||
|
const bool active = !use_limit || (idx < limit);
|
||||||
|
if (active) {
|
||||||
|
max_val = max(max_val, vals[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
max_val = warp_reduce_max(max_val);
|
||||||
|
|
||||||
|
float sum = 0.f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
const int idx = lane + i * WARP_SIZE;
|
||||||
|
const bool active = !use_limit || (idx < limit);
|
||||||
|
if (active) {
|
||||||
|
const float val = expf(vals[i] - max_val);
|
||||||
|
vals[i] = val;
|
||||||
|
sum += val;
|
||||||
|
} else {
|
||||||
|
vals[i] = 0.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = warp_reduce_sum(sum);
|
||||||
|
|
||||||
|
const float inv_sum = 1.0f / sum;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
const int idx = lane + i * WARP_SIZE;
|
||||||
|
const bool active = !use_limit || (idx < limit);
|
||||||
|
if (active) {
|
||||||
|
vals[i] *= inv_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
This kernel does the following:
|
This kernel does the following:
|
||||||
1. softmax over the logits per token [n_experts, n_tokens]
|
1. optionally softmax over the logits per token [n_experts, n_tokens]
|
||||||
2. argmax reduce over the top-k (n_experts_used) logits
|
2. argmax reduce over the top-k (n_experts_used) logits
|
||||||
3. write weights + ids to global memory
|
3. write weights + ids to global memory
|
||||||
4. optionally normalize the weights
|
4. optionally normalize the weights or apply softmax over the selected logits
|
||||||
|
|
||||||
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
|
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
|
||||||
*/
|
*/
|
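Before the kernel itself, a host-side reference for the pipeline the comment above describes: softmax over the logits (unless delayed), repeated argmax to pick the top n_expert_used experts, then either renormalization of the picked weights or a softmax over just the picked logits. This processes a single token and the names are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static void softmax_inplace(std::vector<float> & v) {
    const float mx = *std::max_element(v.begin(), v.end());
    float sum = 0.0f;
    for (float & x : v) { x = std::exp(x - mx); sum += x; }
    for (float & x : v) { x /= sum; }
}

static void topk_moe_reference(std::vector<float> logits, int n_expert_used,
                               bool with_norm, bool delayed_softmax,
                               std::vector<int32_t> & ids, std::vector<float> & weights) {
    if (!delayed_softmax) {
        softmax_inplace(logits);                  // 1. softmax over all logits
    }
    ids.clear();
    weights.clear();
    for (int k = 0; k < n_expert_used; ++k) {     // 2. repeated argmax == top-k
        const int best = (int) (std::max_element(logits.begin(), logits.end()) - logits.begin());
        ids.push_back(best);                      // 3. write ids + weights
        weights.push_back(logits[best]);
        logits[best] = -INFINITY;                 //    exclude from the next iteration
    }
    if (with_norm) {                              // 4a. renormalize the selected weights
        float sum = 0.0f;
        for (float w : weights) sum += w;
        for (float & w : weights) w /= sum;
    } else if (delayed_softmax) {                 // 4b. softmax over the selected logits only
        softmax_inplace(weights);
    }
}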
||||||
template <int n_experts, bool with_norm>
|
template <int n_experts, bool with_norm, bool delayed_softmax = false>
|
||||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
||||||
float * weights,
|
float * weights,
|
||||||
int32_t * ids,
|
int32_t * ids,
|
||||||
|
|
@ -30,51 +75,30 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||||
|
|
||||||
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||||
|
|
||||||
float logits_r[experts_per_thread];
|
float wt[experts_per_thread];
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||||
const int expert = i + threadIdx.x;
|
const int expert = i + threadIdx.x;
|
||||||
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
|
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
float max_val = logits_r[0];
|
if constexpr (!delayed_softmax) {
|
||||||
|
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||||
#pragma unroll
|
|
||||||
for (int i = 1; i < experts_per_thread; i++) {
|
|
||||||
const float val = logits_r[i];
|
|
||||||
max_val = max(val, max_val);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
max_val = warp_reduce_max(max_val);
|
//at this point, each thread holds either a portion of the softmax distribution
|
||||||
|
//or the raw logits. We do the argmax reduce over n_expert_used, each time marking
|
||||||
float wt[experts_per_thread];
|
|
||||||
float tmp = 0.f;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < experts_per_thread; i++) {
|
|
||||||
const float val = logits_r[i];
|
|
||||||
wt[i] = expf(val - max_val);
|
|
||||||
tmp += wt[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
tmp = warp_reduce_sum(tmp);
|
|
||||||
|
|
||||||
const float inv_sum = 1.0f / tmp;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < experts_per_thread; i++) {
|
|
||||||
wt[i] = wt[i] * inv_sum;
|
|
||||||
}
|
|
||||||
|
|
||||||
//at this point, each thread holds a portion of softmax,
|
|
||||||
//we do the argmax reduce over n_expert_used, each time marking
|
|
||||||
//the expert weight as -inf to exclude from the next iteration
|
//the expert weight as -inf to exclude from the next iteration
|
||||||
|
|
||||||
float wt_sum = 0.f;
|
float wt_sum = 0.f;
|
||||||
|
|
||||||
extern __shared__ float data_topk_shared[];
|
float output_weights[experts_per_thread];
|
||||||
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
output_weights[i] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
for (int k = 0; k < n_expert_used; k++) {
|
for (int k = 0; k < n_expert_used; k++) {
|
||||||
float max_val = wt[0];
|
float max_val = wt[0];
|
||||||
|
|
@ -99,10 +123,13 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||||
|
output_weights[k / WARP_SIZE] = max_val;
|
||||||
|
}
|
||||||
|
|
||||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||||
|
|
||||||
wt_shared_ptr[k] = max_val;
|
|
||||||
ids[k] = max_expert;
|
ids[k] = max_expert;
|
||||||
if constexpr (with_norm) {
|
if constexpr (with_norm) {
|
||||||
wt_sum += max_val;
|
wt_sum += max_val;
|
||||||
|
|
@ -114,17 +141,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||||
wt_sum = warp_reduce_sum(wt_sum);
|
wt_sum = warp_reduce_sum(wt_sum);
|
||||||
const float inv_sum = 1.0f / wt_sum;
|
const float inv_sum = 1.0f / wt_sum;
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
|
output_weights[i] *= inv_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
if constexpr (delayed_softmax) {
|
||||||
weights[i] = wt_shared_ptr[i];
|
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
const int idx = i * WARP_SIZE + threadIdx.x;
|
||||||
|
if (idx < n_expert_used) {
|
||||||
|
weights[idx] = output_weights[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool with_norm>
|
template <bool with_norm, bool delayed_softmax = false>
|
||||||
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||||
const float * logits,
|
const float * logits,
|
||||||
float * weights,
|
float * weights,
|
||||||
|
|
@ -132,53 +167,53 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||||
const int n_rows,
|
const int n_rows,
|
||||||
const int n_expert,
|
const int n_expert,
|
||||||
const int n_expert_used) {
|
const int n_expert_used) {
|
||||||
|
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
|
||||||
|
|
||||||
const int rows_per_block = 4;
|
const int rows_per_block = 4;
|
||||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||||
cudaStream_t stream = ctx.stream();
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
|
|
||||||
|
|
||||||
switch (n_expert) {
|
switch (n_expert) {
|
||||||
case 1:
|
case 1:
|
||||||
topk_moe_cuda<1, with_norm>
|
topk_moe_cuda<1, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
topk_moe_cuda<2, with_norm>
|
topk_moe_cuda<2, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
topk_moe_cuda<4, with_norm>
|
topk_moe_cuda<4, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 8:
|
case 8:
|
||||||
topk_moe_cuda<8, with_norm>
|
topk_moe_cuda<8, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 16:
|
case 16:
|
||||||
topk_moe_cuda<16, with_norm>
|
topk_moe_cuda<16, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 32:
|
case 32:
|
||||||
topk_moe_cuda<32, with_norm>
|
topk_moe_cuda<32, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 64:
|
case 64:
|
||||||
topk_moe_cuda<64, with_norm>
|
topk_moe_cuda<64, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 128:
|
case 128:
|
||||||
topk_moe_cuda<128, with_norm>
|
topk_moe_cuda<128, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 256:
|
case 256:
|
||||||
topk_moe_cuda<256, with_norm>
|
topk_moe_cuda<256, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
case 512:
|
case 512:
|
||||||
topk_moe_cuda<512, with_norm>
|
topk_moe_cuda<512, with_norm, delayed_softmax>
|
||||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "fatal error");
|
GGML_ASSERT(false && "fatal error");
|
||||||
|
|
@ -190,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||||
const ggml_tensor * logits,
|
const ggml_tensor * logits,
|
||||||
ggml_tensor * weights,
|
ggml_tensor * weights,
|
||||||
ggml_tensor * ids,
|
ggml_tensor * ids,
|
||||||
const bool with_norm) {
|
const bool with_norm,
|
||||||
|
const bool delayed_softmax) {
|
||||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||||
|
|
@ -198,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||||
const int n_experts = logits->ne[0];
|
const int n_experts = logits->ne[0];
|
||||||
const int n_rows = logits->ne[1];
|
const int n_rows = logits->ne[1];
|
||||||
|
|
||||||
const float * logits_d = (const float *) logits->src[0]->data;
|
const float * logits_d = (const float *) logits->data;
|
||||||
float * weights_d = (float *) weights->data;
|
float * weights_d = (float *) weights->data;
|
||||||
int32_t * ids_d = (int32_t *) ids->data;
|
int32_t * ids_d = (int32_t *) ids->data;
|
||||||
|
|
||||||
|
|
@ -209,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||||
if (with_norm) {
|
if (with_norm) {
|
||||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||||
} else {
|
} else {
|
||||||
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
if (delayed_softmax) {
|
||||||
|
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||||
|
} else {
|
||||||
|
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -242,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
|
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
|
||||||
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||||
|
|
@ -250,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
|
||||||
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||||
|
|
||||||
|
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||||
|
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||||
|
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||||
|
|
||||||
|
GGML_ASSERT(!norm || !delayed_softmax);
|
||||||
|
|
||||||
|
if (delayed_softmax) {
|
||||||
|
return delayed_softmax_ops;
|
||||||
|
}
|
||||||
|
|
||||||
if (norm) {
|
if (norm) {
|
||||||
return norm_ops;
|
return norm_ops;
|
||||||
}
|
}
|
||||||
|
|
||||||
return no_norm_ops;
|
return no_norm_ops;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,10 @@
|
||||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||||
const ggml_tensor * logits,
|
const ggml_tensor * logits,
|
||||||
ggml_tensor * weights,
|
ggml_tensor * weights,
|
||||||
ggml_tensor * top_k,
|
ggml_tensor * ids,
|
||||||
const bool with_norm);
|
const bool with_norm,
|
||||||
|
const bool delayed_softmax = false);
|
||||||
|
|
||||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
|
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
|
||||||
|
|
||||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
|
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
|
||||||
|
|
|
||||||
|
|
@ -28,8 +28,10 @@ if (CXX_IS_HIPCC)
|
||||||
" Prefer setting the HIP compiler directly. See README for details.")
|
" Prefer setting the HIP compiler directly. See README for details.")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
# Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
|
# Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
|
||||||
if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||||
|
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
|
||||||
|
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||||
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
|
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
|
||||||
endif()
|
endif()
|
||||||
cmake_minimum_required(VERSION 3.21)
|
cmake_minimum_required(VERSION 3.21)
|
||||||
|
|
|
||||||
|
|
@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
|
||||||
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
|
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
|
||||||
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
|
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
|
||||||
|
|
||||||
|
static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
|
||||||
|
const struct ggml_tensor * node = cgraph->nodes[node_idx];
|
||||||
|
|
||||||
|
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
|
||||||
|
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return cgraph->use_counts[hash_pos];
|
||||||
|
}
|
||||||
|
|
||||||
// return true if the node's results are only used by N other nodes
|
// return true if the node's results are only used by N other nodes
|
||||||
// and can be fused into their calculations.
|
// and can be fused into their calculations.
|
||||||
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
|
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
|
||||||
const struct ggml_tensor * node = cgraph->nodes[node_idx];
|
const struct ggml_tensor * node = cgraph->nodes[node_idx];
|
||||||
|
|
||||||
// check the use count against how many we're replacing
|
// check the use count against how many we're replacing
|
||||||
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
|
if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
|
||||||
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -638,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
||||||
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
|
||||||
|
const int * node_idxs,
|
||||||
|
int count,
|
||||||
|
const enum ggml_op * ops,
|
||||||
|
const int * outputs,
|
||||||
|
int num_outputs);
|
||||||
|
|
||||||
|
// Returns true if the subgraph formed by {node_idxs} can be fused
|
||||||
|
// checks whether all nodes which are not part of outputs can be elided
|
||||||
|
// by checking if their num_uses are confined to the subgraph
|
||||||
|
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||||
|
int node_idx,
|
||||||
|
int count,
|
||||||
|
const enum ggml_op * ops,
|
||||||
|
const int * outputs,
|
||||||
|
int num_outputs) {
|
||||||
|
GGML_ASSERT(count < 32);
|
||||||
|
if (node_idx + count > cgraph->n_nodes) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int idxs[32];
|
||||||
|
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
idxs[i] = node_idx + i;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
|
||||||
|
}
|
||||||
|
|
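The elision rule the comment above spells out can be restated on a toy graph: a run of nodes is fusable only if every node in the run that is not an external output is read exclusively by other nodes inside the run. The structs and names below are illustrative, not ggml's:

#include <vector>

struct toy_node {
    std::vector<int> inputs; // indices of the nodes this node reads
};

// A contiguous run [start, start + count) can be fused if every node in the run
// that is not listed in outputs is read only from inside the run.
static bool can_fuse_run(const std::vector<toy_node> & nodes, int start, int count,
                         const std::vector<int> & outputs) {
    for (int i = start; i < start + count; ++i) {
        bool is_output = false;
        for (int o : outputs) {
            is_output = is_output || (o == i);
        }
        if (is_output) {
            continue;
        }
        for (int j = 0; j < (int) nodes.size(); ++j) {
            const bool inside = j >= start && j < start + count;
            for (int in : nodes[j].inputs) {
                if (in == i && !inside) {
                    return false; // an intermediate value escapes the subgraph
                }
            }
        }
    }
    return true;
}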
||||||
// Management libraries for fetching more accurate free VRAM data
|
// Management libraries for fetching more accurate free VRAM data
|
||||||
GGML_API int ggml_nvml_init();
|
GGML_API int ggml_nvml_init();
|
||||||
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
|
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
|
||||||
|
|
@ -662,6 +701,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
|
||||||
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
|
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||||
|
int start_idx,
|
||||||
|
std::initializer_list<enum ggml_op> ops,
|
||||||
|
std::initializer_list<int> outputs = {}) {
|
||||||
|
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
|
||||||
|
}
|
||||||
|
|
||||||
// expose GGUF internals for test code
|
// expose GGUF internals for test code
|
||||||
GGML_API size_t gguf_type_size(enum gguf_type type);
|
GGML_API size_t gguf_type_size(enum gguf_type type);
|
||||||
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
||||||
|
|
|
||||||
|
|
@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||||
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (res) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
assert(op->op == GGML_OP_UPSCALE);
|
assert(op->op == GGML_OP_UPSCALE);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@
|
||||||
|
|
||||||
#include <Metal/Metal.h>
|
#include <Metal/Metal.h>
|
||||||
|
|
||||||
|
#include <stdatomic.h>
|
||||||
|
|
||||||
#ifndef TARGET_OS_VISION
|
#ifndef TARGET_OS_VISION
|
||||||
#define TARGET_OS_VISION 0
|
#define TARGET_OS_VISION 0
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -22,6 +24,9 @@
|
||||||
// overload of MTLGPUFamilyMetal3 (not available in some environments)
|
// overload of MTLGPUFamilyMetal3 (not available in some environments)
|
||||||
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
|
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
|
||||||
|
|
||||||
|
// virtual address for GPU memory allocations
|
||||||
|
static atomic_uintptr_t g_addr_device = 0x000000400ULL;
|
||||||
|
|
||||||
#if !GGML_METAL_EMBED_LIBRARY
|
#if !GGML_METAL_EMBED_LIBRARY
|
||||||
// Here to assist with NSBundle Path Hack
|
// Here to assist with NSBundle Path Hack
|
||||||
@interface GGMLMetalClass : NSObject
|
@interface GGMLMetalClass : NSObject
|
||||||
|
|
@ -648,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
return true;
|
return true;
|
||||||
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
|
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
|
||||||
|
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
|
||||||
|
op->src[1]->type == GGML_TYPE_F32 &&
|
||||||
|
op->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
return op->src[0]->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
|
|
@ -657,6 +667,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_LOG:
|
case GGML_OP_LOG:
|
||||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
|
|
@ -693,7 +704,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
// for new head sizes, add checks here
|
// for new head sizes, add checks here
|
||||||
if (op->src[0]->ne[0] != 40 &&
|
if (op->src[0]->ne[0] != 32 &&
|
||||||
|
op->src[0]->ne[0] != 40 &&
|
||||||
op->src[0]->ne[0] != 64 &&
|
op->src[0]->ne[0] != 64 &&
|
||||||
op->src[0]->ne[0] != 80 &&
|
op->src[0]->ne[0] != 80 &&
|
||||||
op->src[0]->ne[0] != 96 &&
|
op->src[0]->ne[0] != 96 &&
|
||||||
|
|
@ -826,7 +838,7 @@ struct ggml_metal_buffer_wrapper {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_metal_buffer {
|
struct ggml_metal_buffer {
|
||||||
void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
|
void * all_data;
|
||||||
size_t all_size;
|
size_t all_size;
|
||||||
|
|
||||||
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
|
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
|
||||||
|
|
@ -964,14 +976,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||||
if (shared) {
|
if (shared) {
|
||||||
res->all_data = ggml_metal_host_malloc(size_aligned);
|
res->all_data = ggml_metal_host_malloc(size_aligned);
|
||||||
res->is_shared = true;
|
res->is_shared = true;
|
||||||
res->owned = true;
|
|
||||||
} else {
|
} else {
|
||||||
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
|
// use virtual address from g_addr_device counter
|
||||||
res->all_data = (void *) 0x000000400ULL;
|
res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
|
||||||
res->is_shared = false;
|
res->is_shared = false;
|
||||||
}
|
}
|
||||||
res->all_size = size_aligned;
|
res->all_size = size_aligned;
|
||||||
|
|
||||||
|
res->owned = true;
|
||||||
|
|
||||||
res->device = ggml_metal_device_get_obj(dev);
|
res->device = ggml_metal_device_get_obj(dev);
|
||||||
res->queue = ggml_metal_device_get_queue(dev);
|
res->queue = ggml_metal_device_get_queue(dev);
|
||||||
|
|
||||||
|
|
@ -982,15 +995,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||||
res->buffers[0].metal = nil;
|
res->buffers[0].metal = nil;
|
||||||
|
|
||||||
if (size_aligned > 0) {
|
if (size_aligned > 0) {
|
||||||
if (props_dev->use_shared_buffers &&shared) {
|
if (props_dev->use_shared_buffers && shared) {
|
||||||
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
|
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
|
||||||
length:size_aligned
|
length:size_aligned
|
||||||
options:MTLResourceStorageModeShared
|
options:MTLResourceStorageModeShared
|
||||||
deallocator:nil];
|
deallocator:nil];
|
||||||
} else {
|
} else {
|
||||||
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
|
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
|
||||||
|
|
||||||
res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1138,7 +1149,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
|
||||||
|
|
||||||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||||
if (buf->is_shared) {
|
if (buf->is_shared) {
|
||||||
memset((char *)tensor->data + offset, value, size);
|
memset((char *) tensor->data + offset, value, size);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1167,7 +1178,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
|
||||||
|
|
||||||
void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
if (buf->is_shared) {
|
if (buf->is_shared) {
|
||||||
memcpy((char *)tensor->data + offset, data, size);
|
memcpy((char *) tensor->data + offset, data, size);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1222,7 +1233,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
|
||||||
|
|
||||||
void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||||
if (buf->is_shared) {
|
if (buf->is_shared) {
|
||||||
memcpy(data, (const char *)tensor->data + offset, size);
|
memcpy(data, (const char *) tensor->data + offset, size);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2136,6 +2136,7 @@ typedef struct {
|
||||||
int32_t sect_1;
|
int32_t sect_1;
|
||||||
int32_t sect_2;
|
int32_t sect_2;
|
||||||
int32_t sect_3;
|
int32_t sect_3;
|
||||||
|
bool src2;
|
||||||
} ggml_metal_kargs_rope;
|
} ggml_metal_kargs_rope;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
|
@ -2398,6 +2399,19 @@ typedef struct {
|
||||||
uint64_t nb1;
|
uint64_t nb1;
|
||||||
} ggml_metal_kargs_conv_transpose_1d;
|
} ggml_metal_kargs_conv_transpose_1d;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int32_t IC;
|
||||||
|
int32_t IH;
|
||||||
|
int32_t IW;
|
||||||
|
int32_t KH;
|
||||||
|
int32_t KW;
|
||||||
|
int32_t OC;
|
||||||
|
int32_t s0;
|
||||||
|
uint64_t nb0;
|
||||||
|
uint64_t nb1;
|
||||||
|
uint64_t nb2;
|
||||||
|
} ggml_metal_kargs_conv_transpose_2d;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint64_t ofs0;
|
uint64_t ofs0;
|
||||||
uint64_t ofs1;
|
uint64_t ofs1;
|
||||||
|
|
@@ -4392,18 +4406,48 @@ kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {

if (tiitg != 0) {
if (args.np == 0) {
return;
}

float acc = 0.0f;
const uint nsg = (ntg.x + 31) / 32;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
float sumf = 0;

for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}

dst[0] = acc;
sumf = simd_sum(sumf);

if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

float total = 0;

if (sgitg == 0) {
float v = 0;

if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}

total = simd_sum(v);

if (tpitg.x == 0) {
dst[0] = total;
}
}
}

template <bool norm>

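The hunk above replaces the single-thread sum with a two-stage reduction: every thread accumulates a strided partial sum, `simd_sum` folds each simdgroup, the per-simdgroup partials land in `shmem_f32`, and simdgroup 0 folds those into `dst[0]`. The following is a minimal host-side C++ sketch of that reduction order only; the thread and simdgroup counts are assumptions for illustration, not values taken from the patch.

```cpp
// Illustrative only: CPU emulation of the reduction order used by the new
// kernel_op_sum_f32 (strided per-thread sums -> per-simdgroup sums -> total).
#include <cstdio>
#include <vector>

int main() {
    const int ntg        = 256; // threads per threadgroup (assumed)
    const int simd_width = 32;  // Apple GPU simdgroup width
    const int nsg        = (ntg + simd_width - 1) / simd_width;

    std::vector<float> src(1000);
    for (size_t i = 0; i < src.size(); ++i) src[i] = 0.001f * (float) i;

    // per-thread strided partial sums (the i0 loop in the kernel)
    std::vector<float> sumf(ntg, 0.0f);
    for (int t = 0; t < ntg; ++t) {
        for (size_t i0 = (size_t) t; i0 < src.size(); i0 += ntg) {
            sumf[t] += src[i0];
        }
    }

    // simd_sum within each simdgroup -> shmem_f32[sgitg]
    std::vector<float> shmem(nsg, 0.0f);
    for (int sg = 0; sg < nsg; ++sg) {
        for (int l = 0; l < simd_width; ++l) {
            shmem[sg] += sumf[sg * simd_width + l];
        }
    }

    // simdgroup 0 reduces the per-simdgroup partials and writes dst[0]
    float total = 0.0f;
    for (int sg = 0; sg < nsg; ++sg) {
        total += shmem[sg];
    }

    printf("sum = %f\n", total);
    return 0;
}
```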
@@ -6413,7 +6457,7 @@ kernel void kernel_rope_norm(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@@ -6466,7 +6510,7 @@ kernel void kernel_rope_neox(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@@ -6533,7 +6577,7 @@ kernel void kernel_rope_multi(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@@ -6600,7 +6644,7 @@ kernel void kernel_rope_vision(
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@@ -6810,6 +6854,97 @@ kernel void kernel_conv_transpose_1d<half>(
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);

typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);

template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {

const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];

const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];

float v = 0.0f;

for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;

if (in_y < 0 || in_y % args.s0) continue;

in_y /= args.s0;

if (in_y >= args.IH) continue;

int64_t in_x = out_x - kw;

if (in_x < 0 || in_x % args.s0) continue;

in_x /= args.s0;

if (in_x >= args.IW) continue;

const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;

v += (float)src0[kernel_idx] * src1[input_idx];
}

const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;

threadgroup_barrier(mem_flags::mem_threadgroup);

if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}

device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}

template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

kernel void kernel_upscale_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,

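The new kernel is a gather-style transposed convolution: each threadgroup owns one output element (out_x, out_y, out_c), each thread owns one kernel tap (kw, kh), and the per-tap contributions are summed through `shared_sum`. Below is a minimal single-threaded C++ reference that uses the same indexing, offered only as a reading aid; the output extents assume ggml's no-padding transposed convolution (OW = (IW − 1)·s0 + KW), which is an assumption here rather than something stated in the hunk.

```cpp
// Illustrative CPU reference for the gather-style transposed 2-D convolution.
// src0: kernel [KW, KH, OC, IC], src1: input [IW, IH, IC], dst: [OW, OH, OC].
#include <vector>

static void conv_transpose_2d_ref(const std::vector<float> & src0,
                                  const std::vector<float> & src1,
                                  std::vector<float> & dst,
                                  int IW, int IH, int IC,
                                  int KW, int KH, int OC, int s0) {
    const int OW = (IW - 1)*s0 + KW; // assumed output size (no padding)
    const int OH = (IH - 1)*s0 + KH;
    dst.assign((size_t) OW * OH * OC, 0.0f);

    for (int out_c = 0; out_c < OC; ++out_c) {
        for (int out_y = 0; out_y < OH; ++out_y) {
            for (int out_x = 0; out_x < OW; ++out_x) {
                float v = 0.0f;
                for (int in_c = 0; in_c < IC; ++in_c) {
                    for (int kh = 0; kh < KH; ++kh) {
                        const int ty = out_y - kh;
                        if (ty < 0 || ty % s0) continue;        // tap does not hit an input row
                        const int in_y = ty / s0;
                        if (in_y >= IH) continue;
                        for (int kw = 0; kw < KW; ++kw) {
                            const int tx = out_x - kw;
                            if (tx < 0 || tx % s0) continue;    // tap does not hit an input column
                            const int in_x = tx / s0;
                            if (in_x >= IW) continue;
                            const int input_idx  = (IW*IH)*in_c + IW*in_y + in_x;
                            const int kernel_idx = (KH*KW*OC)*in_c + (KH*KW)*out_c + KW*kh + kw;
                            v += src0[kernel_idx] * src1[input_idx];
                        }
                    }
                }
                dst[(size_t) (OW*OH)*out_c + OW*out_y + out_x] = v;
            }
        }
    }
}
```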
@@ -7938,8 +8073,30 @@ kernel void kernel_flash_attn_ext(
half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8

#define FA_TYPES_F32 \
half, half4, simdgroup_half8x8, \
float, float4x4, simdgroup_float8x8, \
float, float4x4, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float2, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8

typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;

template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;

template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;

@@ -7952,6 +8109,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;

#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;

@@ -7964,6 +8122,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif

template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;

@@ -7975,6 +8134,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;

@@ -7986,6 +8146,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;

@@ -7997,6 +8158,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;

@@ -8008,6 +8170,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;

@@ -8543,8 +8706,28 @@ kernel void kernel_flash_attn_ext_vec(
float, float4, \
float4

#define FA_TYPES_F32 \
half4, \
float4, \
float4, \
float, \
float, float4, \
float4

typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;

@@ -8555,6 +8738,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;

@@ -8565,6 +8749,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;

@@ -8575,6 +8760,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;

@@ -8585,6 +8771,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;

@@ -8595,6 +8782,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;

@@ -8605,6 +8793,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;

@@ -251,6 +251,7 @@ typedef struct {
int32_t sect_1;
int32_t sect_2;
int32_t sect_3;
bool src2;
} ggml_metal_kargs_rope;

typedef struct {

@@ -513,6 +514,19 @@ typedef struct {
uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;

typedef struct {
int32_t IC;
int32_t IH;
int32_t IW;
int32_t KH;
int32_t KW;
int32_t OC;
int32_t s0;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;

typedef struct {
uint64_t ofs0;
uint64_t ofs1;

@@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
} break;
case GGML_OP_UPSCALE:
{
n_fuse = ggml_metal_op_upscale(ctx, idx);

@@ -866,12 +870,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

int nth = 32; // SIMD width

while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}

nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, (int) n);

const int nsg = (nth + 31) / 32;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);

ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);

ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);

return 1;
}

@@ -2969,6 +2986,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
/* sect_1 =*/ sect_1,
/* sect_2 =*/ sect_2,
/* sect_3 =*/ sect_3,
/* src2 =*/ op->src[2] != nullptr,
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);

@@ -3104,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
return 1;
}

int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);

const int32_t s0 = ((const int32_t *)(op->op_params))[0];

const int32_t IC = op->src[1]->ne[2];
const int32_t IH = op->src[1]->ne[1];
const int32_t IW = op->src[1]->ne[0];

const int32_t KH = op->src[0]->ne[1];
const int32_t KW = op->src[0]->ne[0];

const int32_t OW = op->ne[0];
const int32_t OH = op->ne[1];
const int32_t OC = op->ne[2];

ggml_metal_kargs_conv_transpose_2d args = {
/*.IC =*/ IC,
/*.IH =*/ IH,
/*.IW =*/ IW,
/*.KH =*/ KH,
/*.KW =*/ KW,
/*.OC =*/ OC,
/*.s0 =*/ s0,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);

// Metal requires buffer size to be multiple of 16 bytes
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);

return 1;
}

int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

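The dispatch above launches one OW x OH x OC grid of threadgroups with KW x KH threads each, so the threadgroup buffer only needs one float per kernel tap, padded to a 16-byte multiple. A tiny sanity check of that padding, using the standard ggml GGML_PAD macro (reproduced here as an assumption, not copied from this patch):

```cpp
// Hypothetical check only: threadgroup memory for a 3x3 kernel is padded from
// 36 bytes up to 48 bytes, matching the set_threadgroup_memory_size call above.
#include <cassert>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    assert(GGML_PAD(3 * 3 * sizeof(float), 16) == 48);
    return 0;
}
```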
@@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);

@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {

if (tiitg != 0) {
if (args.np == 0) {
return;
}

float acc = 0.0f;
const uint nsg = (ntg.x + 31) / 32;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
float sumf = 0;

for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}

dst[0] = acc;
sumf = simd_sum(sumf);

if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

float total = 0;

if (sgitg == 0) {
float v = 0;

if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}

total = simd_sum(v);

if (tpitg.x == 0) {
dst[0] = total;
}
}
}

template <bool norm>

@@ -3748,7 +3778,7 @@ kernel void kernel_rope_norm(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@@ -3801,7 +3831,7 @@ kernel void kernel_rope_neox(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@@ -3868,7 +3898,7 @@ kernel void kernel_rope_multi(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@@ -3935,7 +3965,7 @@ kernel void kernel_rope_vision(
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

@ -4145,6 +4175,97 @@ kernel void kernel_conv_transpose_1d<half>(
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]);

typedef void (conv_transpose_2d_t)(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]);

template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const T * src0,
        device const float * src1,
        device char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {

    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    float v = 0.0f;

    for (int64_t in_c = 0; in_c < args.IC; in_c++) {
        int64_t in_y = out_y - kh;

        if (in_y < 0 || in_y % args.s0) continue;

        in_y /= args.s0;

        if (in_y >= args.IH) continue;

        int64_t in_x = out_x - kw;

        if (in_x < 0 || in_x % args.s0) continue;

        in_x /= args.s0;

        if (in_x >= args.IW) continue;

        const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
        const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;

        v += (float)src0[kernel_idx] * src1[input_idx];
    }

    const uint tid = tpitg.y * ntg.x + tpitg.x;
    shared_sum[tid] = v;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float total = 0.0f;
        const uint num_threads = ntg.x * ntg.y;
        for (uint i = 0; i < num_threads; i++) {
            total += shared_sum[i];
        }

        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
        dst_ptr[0] = total;
    }
}

template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const half * src0,
        device const float * src1,
        device char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);
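Note: for reference, the Metal kernel above accumulates a transposed 2D convolution per output element and then reduces partial sums across the threadgroup. The sketch below is a minimal CPU equivalent of the same indexing, assuming row-major tensors; the helper name and the explicit IC/OC/IW/IH/KW/KH/OW/OH/s0 parameters are illustrative only and are not part of this commit.

#include <cstdint>
#include <vector>

// dst[oc][oy][ox] = sum over ic, kh, kw of kernel[ic][oc][kh][kw] * input[ic][iy][ix],
// where iy = (oy - kh) / s0 and ix = (ox - kw) / s0, used only when the division is exact
// and the coordinate is in bounds, i.e. the same checks the Metal kernel performs.
static void conv_transpose_2d_ref(const std::vector<float> & kernel,  // IC*OC*KH*KW
                                  const std::vector<float> & input,   // IC*IH*IW
                                  std::vector<float> & output,        // OC*OH*OW
                                  int64_t IC, int64_t OC, int64_t IW, int64_t IH,
                                  int64_t KW, int64_t KH, int64_t OW, int64_t OH, int64_t s0) {
    for (int64_t oc = 0; oc < OC; ++oc) {
        for (int64_t oy = 0; oy < OH; ++oy) {
            for (int64_t ox = 0; ox < OW; ++ox) {
                float v = 0.0f;
                for (int64_t ic = 0; ic < IC; ++ic) {
                    for (int64_t kh = 0; kh < KH; ++kh) {
                        for (int64_t kw = 0; kw < KW; ++kw) {
                            int64_t iy = oy - kh;
                            int64_t ix = ox - kw;
                            if (iy < 0 || iy % s0 || ix < 0 || ix % s0) continue;
                            iy /= s0;
                            ix /= s0;
                            if (iy >= IH || ix >= IW) continue;
                            // same layouts as kernel_idx/input_idx in the Metal kernel
                            v += kernel[((ic*OC + oc)*KH + kh)*KW + kw] * input[(ic*IH + iy)*IW + ix];
                        }
                    }
                }
                output[(oc*OH + oy)*OW + ox] = v;
            }
        }
    }
}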
kernel void kernel_upscale_f32(
        constant ggml_metal_kargs_upscale & args,
        device const char * src0,
@ -5273,8 +5394,30 @@ kernel void kernel_flash_attn_ext(
    half, half4, simdgroup_half8x8
    //float, float4, simdgroup_float8x8

#define FA_TYPES_F32 \
    half, half4, simdgroup_half8x8, \
    float, float4x4, simdgroup_float8x8, \
    float, float4x4, simdgroup_float8x8, \
    float, simdgroup_float8x8, \
    float, float2, simdgroup_float8x8, \
    float, float4, simdgroup_float8x8
    //half, half4, simdgroup_half8x8

typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;

template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;

template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
@ -5287,6 +5430,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;

#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
@ -5299,6 +5443,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif

template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
@ -5310,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
@ -5321,6 +5467,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
@ -5332,6 +5479,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
@ -5343,6 +5491,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
@ -5878,8 +6027,28 @@ kernel void kernel_flash_attn_ext_vec(
    float, float4, \
    float4

#define FA_TYPES_F32 \
    half4, \
    float4, \
    float4, \
    float, \
    float, float4, \
    float4

typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
@ -5890,6 +6059,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
@ -5900,6 +6070,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
@ -5910,6 +6081,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
@ -5920,6 +6092,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
@ -5930,6 +6103,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
@ -5940,6 +6114,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
@ -1,9 +1,18 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW)

if (POLICY CMP0147)
    # Parallel build custom build steps
    cmake_policy(SET CMP0147 NEW)
endif()

find_package(Vulkan COMPONENTS glslc REQUIRED)

if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
    # Parallel build object files
    add_definitions(/MP)
endif()

function(detect_host_compiler)
    if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
        find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
@ -97,8 +97,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

#define GGML_VK_MAX_NODES 8192

#define MAX_VK_BUFFERS 256

#define VK_CHECK(err, msg) \
    do { \
        vk::Result err_ = (err); \
@ -387,6 +385,14 @@ enum shader_reduction_mode {

static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;

static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                           GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                           GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                       GGML_OP_VIEW, GGML_OP_GET_ROWS };

struct vk_device_struct {
    std::recursive_mutex mutex;
@ -584,6 +590,9 @@ struct vk_device_struct {
    vk_pipeline pipeline_pool2d_f32;
    vk_pipeline pipeline_rwkv_wkv6_f32;
    vk_pipeline pipeline_rwkv_wkv7_f32;
    vk_pipeline pipeline_ssm_scan_f32_d128;
    vk_pipeline pipeline_ssm_scan_f32_d256;
    vk_pipeline pipeline_ssm_conv_f32;
    vk_pipeline pipeline_opt_step_adamw_f32;
    vk_pipeline pipeline_opt_step_sgd_f32;
    vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@ -597,6 +606,9 @@ struct vk_device_struct {

    vk_pipeline pipeline_flash_attn_split_k_reduce;

    // [2] is {!norm, norm}
    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];

    std::vector<vk_pipeline_ref> all_pipelines;

    std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
@ -940,6 +952,11 @@ struct vk_op_multi_add_push_constants {
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);

struct vk_op_topk_moe_push_constants {
    uint32_t n_rows;
    uint32_t n_expert_used;
};

struct vk_op_add_id_push_constants {
    uint32_t ne0;
    uint32_t ne1;
@ -1089,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants {
    uint32_t C;
    uint32_t H;
};

struct vk_op_ssm_scan_push_constants {
    uint32_t nb02, nb03, nb12, nb13;
    uint32_t nb21, nb22, nb31;
    uint32_t nb42, nb43, nb52, nb53;
    uint32_t s_off;
    uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
    uint32_t nb01, nb02;
    uint32_t nb11;
    uint32_t dst_nb0, dst_nb1, dst_nb2;
    uint32_t nc, ncs, nr, n_t, n_s;
};

struct vk_op_conv2d_push_constants {
    uint32_t Cout;
@ -1281,7 +1311,6 @@ struct ggml_vk_garbage_collector {
    std::vector<vk_semaphore> tl_semaphores;
    std::vector<vk_semaphore> semaphores;
    std::vector<vk::Event> events;
    std::vector<vk_buffer> temp_buffers;
    std::vector<vk_context> contexts;
};
@ -1452,8 +1481,6 @@ struct ggml_backend_vk_context {
    // and set to true after the buffer contents are consumed.
    bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;

    vk_buffer buffer_pool[MAX_VK_BUFFERS];

    vk_context_ref compute_ctx;
    vk_context_ref transfer_ctx;
@ -2651,11 +2678,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
        } \
    }

    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
    if (device->coopmat1_fa_support) {
        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
@ -2663,6 +2692,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    if (device->coopmat2) {
        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
@ -3590,6 +3620,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

    if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
    } else {
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
    }

    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@ -3700,6 +3740,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

    for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
    }

    for (auto &c : compiles) {
        c.wait();
    }
@ -4690,7 +4735,14 @@ static void ggml_vk_instance_init() {
                vk::PhysicalDeviceIDProperties old_id;
                old_props.pNext = &old_id;
                devices[k].getProperties2(&old_props);
                return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
                bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
                equals = equals || (
                    old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
                    std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
                );

                return equals;
            }
        );
        if (old_device == vk_instance.device_indices.end()) {
@ -4728,6 +4780,7 @@ static void ggml_vk_instance_init() {
#endif
            break;
        }
        driver_priorities[vk::DriverId::eMesaDozen] = 100;

        if (driver_priorities.count(old_driver.driverID)) {
            old_priority = driver_priorities[old_driver.driverID];
@ -5101,71 +5154,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
    return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
}

static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
    VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
    VK_LOG_MEMORY("ggml_vk_pool_malloc");

    int best_i = -1;
    size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
    int worst_i = -1;
    size_t worst_size = 0; //largest unused buffer seen so far
    for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
        vk_buffer &b = ctx->buffer_pool[i];
        if (b != nullptr && b->size >= size && b->size < best_size) {
            best_i = i;
            best_size = b->size;
        }
        if (b != nullptr && b->size > worst_size) {
            worst_i = i;
            worst_size = b->size;
        }
    }
    if(best_i != -1) {
        //found the smallest buffer that fits our needs
        vk_buffer b = ctx->buffer_pool[best_i];
        ctx->buffer_pool[best_i].reset();
        return b;
    }
    if(worst_i != -1) {
        //no buffer that fits our needs, resize largest one to save memory
        vk_buffer& b = ctx->buffer_pool[worst_i];
        ggml_vk_destroy_buffer(b);
    }

    return ggml_vk_create_buffer_device(ctx->device, size);
}

static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
    VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
    for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
        vk_buffer& b = ctx->buffer_pool[i];
        if (b == nullptr) {
            b = buffer;
            return;
        }
    }
    std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
    ggml_vk_destroy_buffer(buffer);
}

// Returns an available temporary buffer that may only be used temporarily, it will be reused
static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
    // Try to find existing temp buffer with enough capacity
    for (auto& buffer : ctx->gc.temp_buffers) {
        if (buffer->size >= size) {
            return buffer;
        }
    }

    VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");

    // Otherwise create new buffer
    vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
    ctx->gc.temp_buffers.push_back(buf);

    return buf;
}

static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
    VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
    vk_buffer buf = ggml_vk_create_buffer(device, size,
@ -7459,8 +7447,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    }

    const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
    const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
    uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
    const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
    uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));

    // For F32, the shader treats it as a block of size 4 (for vec4 loads)
    if (k->type == GGML_TYPE_F32) {
        k_stride /= 4;
    }
    if (v->type == GGML_TYPE_F32) {
        v_stride /= 4;
    }

    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
    bool aligned = (KV % alignment) == 0 &&
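Note: the F32 special case above only changes the units of the K/V row strides. A rough sketch of the arithmetic, with made-up example numbers, assuming the shader consumes F32 rows as vec4 blocks as the comment in the hunk states; the helper below is illustrative and not part of this commit.

#include <cstdint>

// Stride is first computed in elements (bytes / element size); for F32 the flash-attention
// shader loads vec4 blocks, so the stride handed to it is expressed in 4-float blocks.
static uint32_t fa_row_stride(uint64_t nb1_bytes, uint64_t type_size, bool is_f32) {
    uint32_t stride = (uint32_t)(nb1_bytes / type_size);
    if (is_f32) {
        stride /= 4;
    }
    return stride;
}

// Example: a contiguous F32 row of 128 values -> nb1 = 512 bytes -> 128 elements -> 32 vec4 blocks.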
@ -7974,6 +7970,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);

        if (ctx->num_additional_fused_ops) {
            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
            GGML_ASSERT(idx < num_topk_moe_pipelines);
            bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
            return ctx->device->pipeline_topk_moe[idx][with_norm];
        }

        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
        }
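Note: the fused top-k MoE path above picks a pipeline specialized for a power-of-two expert count (the {device->subgroup_size, 1u<<i, norm} specializations created earlier in ggml_vk_load_shaders). A minimal sketch of the bucket selection, assuming dst->ne[0] holds the expert count; the standalone helper below is illustrative only and not part of this commit.

#include <cmath>
#include <cstdint>

// Pipelines are built for expert counts 1u << i, i in [0, num_topk_moe_pipelines);
// ceil(log2(n_expert)) picks the smallest specialization whose width covers n_expert.
static uint32_t topk_moe_pipeline_index(uint32_t n_expert) {
    return (uint32_t)ceilf(log2f((float)n_expert)); // 8 experts -> 3, 60 experts -> 6 (width 64)
}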
@ -8089,6 +8092,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rwkv_wkv7_f32;
        }
        return nullptr;
    case GGML_OP_SSM_SCAN:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            const uint32_t d_state = src0->ne[0];
            if (d_state == 128) {
                return ctx->device->pipeline_ssm_scan_f32_d128;
            } else if (d_state == 256) {
                return ctx->device->pipeline_ssm_scan_f32_d256;
            }
        }
        return nullptr;
    case GGML_OP_SSM_CONV:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_ssm_conv_f32;
        }
        return nullptr;
    case GGML_OP_OPT_STEP_ADAMW:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_opt_step_adamw_f32;
@ -8583,6 +8601,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
            }
        }
        break;
    case GGML_OP_SSM_CONV:
        {
            const uint32_t nr = src0->ne[1];
            const uint32_t n_t = dst->ne[1];
            const uint32_t n_s = dst->ne[2];
            elements = { nr, n_t, n_s };
        }
        break;
    default:
        elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
        break;
@@ -9029,6 +9055,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
    );
}

static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];
    const ggml_tensor * src3 = dst->src[3];
    const ggml_tensor * src4 = dst->src[4];
    const ggml_tensor * src5 = dst->src[5];

    GGML_ASSERT(dst->buffer != nullptr);

    const uint32_t head_dim = src0->ne[1];
    const uint32_t n_head = src1->ne[1];
    const uint32_t n_group = src4->ne[1];
    const uint32_t n_tok = src1->ne[2];
    const uint32_t n_seq = src1->ne[3];

    bool is_mamba2 = (src3->nb[1] == sizeof(float));
    GGML_ASSERT(is_mamba2);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
    GGML_ASSERT(pipeline != nullptr);

    if (dryrun) {
        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
        return;
    }

    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    const vk_op_ssm_scan_push_constants pc = {
        (uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
        (uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
        (uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
        (uint32_t)src3->nb[1],
        (uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
        (uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
        (uint32_t)s_off,
        n_head, head_dim, n_group, n_tok
    };

    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
    ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
    }

    vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
    size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
    bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };

    if (ctx->device->uma) {
        for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
            srcs_uma[i] = d_srcs[i] != nullptr;
        }
        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
        dst_uma = d_D != nullptr;
    }

    if (!dst_uma) {
        d_D = dst_buf_ctx->dev_buffer;
        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
    }
    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
        if (!srcs_uma[i]) {
            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
        }
    }

    size_t dst_size = ggml_nbytes(dst);
    size_t src_sizes[GGML_MAX_SRC];
    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
        src_sizes[i] = ggml_nbytes(dst->src[i]);
    }

    std::array<uint32_t, 3> elements;

    const int splitH = 16;
    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
    const uint32_t num_workgroups_y = n_seq;
    elements = { num_workgroups_x, num_workgroups_y, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
        vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
        vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
        vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
        vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
        vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
        vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
        vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
        vk_subbuffer{ d_D, dst_offset, dst_size }
    }, pc, elements);
}

static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
        (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
        (uint32_t)src1->nb[1],
        (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
        (uint32_t)src1->ne[0],
        (uint32_t)src0->ne[0],
        (uint32_t)src0->ne[1],
        (uint32_t)dst->ne[1],
        (uint32_t)dst->ne[2],
    }, dryrun);
}

static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
    const ggml_tensor * x = dst->src[0];
    const ggml_tensor * g = dst->src[1];
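The dispatch sizing in ggml_vk_ssm_scan above splits the flattened n_head * head_dim dimension into chunks of splitH = 16 columns per workgroup, with one row of workgroups per sequence. A small standalone sketch of that sizing, assuming CEIL_DIV rounds up as elsewhere in ggml (the sizes below are made up):

#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring CEIL_DIV.
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    const uint32_t n_head = 24, head_dim = 64, n_seq = 2, split_h = 16;
    // x: one workgroup per split_h columns of n_head * head_dim; y: one per sequence.
    printf("dispatch = { %u, %u, 1 }\n", ceil_div(n_head * head_dim, split_h), n_seq);
    return 0;
}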
@@ -9425,6 +9562,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}

static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {

    bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
    ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
    ggml_tensor * ids = cgraph->nodes[node_idx + 3];

    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows = logits->ne[1];
    const int n_expert_used = weights->ne[1];

    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);

    if (dryrun) {
        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
        return;
    }

    ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
    ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;

    vk_buffer d_logits = nullptr;
    size_t logits_buf_offset = 0;
    vk_buffer d_weights = nullptr;
    size_t weights_buf_offset = 0;
    vk_buffer d_ids = nullptr;
    size_t ids_buf_offset = 0;

    bool logits_uma = false;
    bool weights_uma = false;
    bool ids_uma = false;

    if (ctx->device->uma) {
        ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
        ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
        logits_uma = d_logits != nullptr;
        weights_uma = d_weights != nullptr;
        ids_uma = d_ids != nullptr;
    }

    if (!logits_uma) {
        d_logits = logits_buf_ctx->dev_buffer;
        logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
        GGML_ASSERT(d_logits != nullptr);
    }
    if (!weights_uma) {
        d_weights = weights_buf_ctx->dev_buffer;
        weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
        GGML_ASSERT(d_weights != nullptr);
    }
    if (!ids_uma) {
        d_ids = ids_buf_ctx->dev_buffer;
        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
        GGML_ASSERT(d_ids != nullptr);
    }

    vk_op_topk_moe_push_constants pc;
    pc.n_rows = n_rows;
    pc.n_expert_used = n_expert_used;

    GGML_ASSERT(n_expert_used <= n_experts);

    const uint32_t rows_per_block = 4;
    std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        {
            ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
            ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
            ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
        }, pc, elements);
}

static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
    const int n_dims = ((int32_t *) dst->op_params)[1];
    const int mode = ((int32_t *) dst->op_params)[2];
@@ -10861,6 +11079,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_CONV_2D_DW:
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_RWKV_WKV7:
    case GGML_OP_SSM_SCAN:
    case GGML_OP_SSM_CONV:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_FLASH_ATTN_EXT:
    case GGML_OP_OPT_STEP_ADAMW:
@@ -11008,11 +11228,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
        ctx->unsynced_nodes_read.clear();
        ggml_vk_sync_buffers(ctx, compute_ctx);
    }
-   // Add the last fused node and all fused source nodes to the unsynchronized list.
-   const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
-   ctx->unsynced_nodes_written.push_back(last_node);
    // Add all fused nodes to the unsynchronized lists.
    for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
        const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
        // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
        ctx->unsynced_nodes_written.push_back(cur_node);
        for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
            if (!cur_node->src[j]) {
                continue;
@@ -11179,7 +11399,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
        break;
    case GGML_OP_SOFT_MAX:
        if (ctx->num_additional_fused_ops) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
        } else {
            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
        }

        break;
    case GGML_OP_SOFT_MAX_BACK:
@@ -11278,6 +11502,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
        break;

    case GGML_OP_SSM_SCAN:
        ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);

        break;

    case GGML_OP_SSM_CONV:
        ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);

        break;

    case GGML_OP_OPT_STEP_ADAMW:
        ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
@@ -11389,6 +11623,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
    case GGML_OP_CONV_2D_DW:
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_RWKV_WKV7:
    case GGML_OP_SSM_SCAN:
    case GGML_OP_SSM_CONV:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:
@@ -11498,10 +11734,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
// Clean up after graph processing is done
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
    VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
-   for (auto& buffer : ctx->gc.temp_buffers) {
-       ggml_vk_pool_free(ctx, buffer);
-   }
-   ctx->gc.temp_buffers.clear();
    ctx->prealloc_y_last_pipeline_used = {};

    ctx->unsynced_nodes_written.clear();
@@ -11544,10 +11776,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
    ggml_vk_destroy_buffer(ctx->prealloc_split_k);
    ctx->prealloc_y_last_pipeline_used = nullptr;

-   for (auto& buffer : ctx->buffer_pool) {
-       ggml_vk_destroy_buffer(buffer);
-   }
-
    ctx->prealloc_size_x = 0;
    ctx->prealloc_size_y = 0;
    ctx->prealloc_size_split_k = 0;
@@ -11988,6 +12216,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
    return true;
}

static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
                                      int node_idx, bool with_norm) {

    if (with_norm) {
        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
            return false;
        }
        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
                return false;
            }
        }
    } else {
        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
            return false;
        }
        for (size_t i = 0; i < topk_moe.size(); ++i) {
            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
                return false;
            }
        }
    }

    const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];

    const float * op_params = (const float *)softmax->op_params;

    float scale = op_params[0];
    float max_bias = op_params[1];

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2
    if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
        return false;
    }

    // Check that the nodes don't have any unexpected uses
    const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
    const ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
    const ggml_tensor * view     = cgraph->nodes[node_idx + 3];
    const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
    const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
    const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
    const ggml_tensor * div      = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
    const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;

    // softmax is used by reshape and argsort
    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
        reshape1->src[0] != softmax ||
        argsort->src[0] != softmax) {
        return false;
    }
    // reshape is used by get_rows
    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
        get_rows->src[0] != reshape1) {
        return false;
    }
    // argsort is used by view
    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
        view->src[0] != argsort) {
        return false;
    }
    // view is written (via argsort), we can skip checking it

    if (with_norm) {
        // get_rows is used by reshape
        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
            reshape5->src[0] != get_rows) {
            return false;
        }

        // reshape is used by sum_rows and div
        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
            sum_rows->src[0] != reshape5 ||
            div->src[0] != reshape5) {
            return false;
        }

        // sum_rows is used by div
        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
            div->src[1] != sum_rows) {
            return false;
        }

        // div/reshape are written
        if (reshape8->src[0] != div) {
            return false;
        }
    }

    if (!ctx->device->subgroup_arithmetic ||
        !ctx->device->subgroup_shuffle ||
        !ctx->device->subgroup_require_full_support ||
        ctx->device->disable_fusion) {
        return false;
    }

    return true;
}

static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {

    const ggml_tensor *first_node = cgraph->nodes[node_idx];
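ggml_vk_can_fuse_topk_moe assumes a fixed op sequence starting at node_idx; the offsets used above (argsort at +2, view at +3, get_rows at +4, and the normalization tail at +5..+8) only make sense against that pattern. The following is a hedged sketch of the sequences the checks appear to expect; the actual topk_moe / topk_moe_norm arrays are defined elsewhere in ggml-vulkan.cpp and may differ in detail:

#include <array>
#include "ggml.h"

// Sketch only: assumed node patterns for the fused top-k MoE routing.
static constexpr std::array<ggml_op, 5> topk_moe_sketch = {
    GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
};
static constexpr std::array<ggml_op, 9> topk_moe_norm_sketch = {
    GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
    GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_DIV,    GGML_OP_RESHAPE,
};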
@@ -12063,6 +12405,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
            ctx->num_additional_fused_ops = num_adds - 1;
        } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
        } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
            ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
        } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
            ctx->num_additional_fused_ops = topk_moe.size() - 1;
        }
    }
    ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12160,6 +12506,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
            ctx->num_additional_fused_ops = num_adds - 1;
        } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
        } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
            ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
        } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
            ctx->num_additional_fused_ops = topk_moe.size() - 1;
        }
    }
@@ -12167,10 +12517,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
        bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
        bool submit = (submitted_nodes >= nodes_per_submit) ||
                      (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                     (i + ctx->num_additional_fused_ops == last_node) ||
+                     (i + ctx->num_additional_fused_ops >= last_node) ||
                      (almost_ready && !ctx->almost_ready_fence_pending);

-       bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
+       bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);

        if (vk_perf_logger_enabled) {
            if (ctx->compute_ctx.expired()) {
@@ -12292,6 +12642,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
    while (first_unused < graph->n_nodes) {
        std::vector<int> current_set;

        // Avoid reordering topk_moe_norm
        if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
            bool is_topk_moe_norm = true;
            for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
                if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
                    is_topk_moe_norm = false;
                }
            }
            if (is_topk_moe_norm) {
                for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
                    new_order.push_back(graph->nodes[first_unused + j]);
                    used[first_unused + j] = true;
                }
                while (first_unused < graph->n_nodes && used[first_unused]) {
                    first_unused++;
                }
                continue;
            }
        }
        // First, grab the next unused node.
        current_set.push_back(first_unused);
@@ -12797,6 +13166,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            }
            switch (op->src[1]->type) {
            case GGML_TYPE_F16:
            case GGML_TYPE_F32:
            case GGML_TYPE_Q4_0:
            case GGML_TYPE_Q8_0:
                // supported in scalar and coopmat2 paths
@@ -13004,6 +13374,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_RWKV_WKV7:
        return true;
    case GGML_OP_SSM_SCAN:
        {
            for (int i = 0; i < 6; i++) {
                if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
                    return false;
                }
            }
            if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
                return false;
            }
            if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
                return false;
            }

            const uint32_t d_state = op->src[0]->ne[0];
            const uint32_t head_dim = op->src[0]->ne[1];

            bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
            if (!is_mamba2) {
                return false;
            }

            if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
                return false;
            }

            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
            const vk_device& device = ggml_vk_get_device(ctx->device);

            const uint32_t SPLIT_H = 16;

            size_t stateC_size = SPLIT_H * d_state * sizeof(float);

            if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
                return false;
            }

            return true;
        }
    case GGML_OP_SSM_CONV:
        return true;
    case GGML_OP_CONV_TRANSPOSE_1D:
        return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
    case GGML_OP_CONV_2D:
@@ -13386,14 +13797,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
    struct ggml_context * ggml_ctx = ggml_init(iparams);

-   std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
-   std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
-   std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
-   const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
+   std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+   std::array<size_t, GGML_MAX_SRC> src_size = {};
+   std::array<void *, GGML_MAX_SRC> src_buffer = {};
+   const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};

    struct ggml_tensor * tensor_clone = nullptr;

-   for (int i = 0; i < 6; i++) {
+   for (int i = 0; i < GGML_MAX_SRC; i++) {
        ggml_tensor * srci = tensor->src[i];
        if (fused_rms_norm_mul) {
            rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
@@ -13700,6 +14111,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
            src_clone[2]);
    } else if (tensor->op == GGML_OP_ADD_ID) {
        tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
    } else if (tensor->op == GGML_OP_SSM_SCAN) {
        tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
            src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
    } else if (tensor->op == GGML_OP_SSM_CONV) {
        tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
    }
    else {
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -13721,7 +14137,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
    memcpy(comp_result, tensor_clone->data, comp_size);
    memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);

-   for (int i = 0; i < 6; i++) {
+   for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (src_buffer[i] != nullptr) {
            free(src_buffer[i]);
        }
@@ -1,6 +1,18 @@
#include "types.glsl"

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
    vec4 block;
};

float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const vec4 v = bl.block;
    const uint idx = coordInBlock[1];
    const f16vec4 vf16 = f16vec4(v);
    return vf16[idx];
}

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
    block_q4_0_packed16 block;
};
@@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif
@@ -345,7 +345,7 @@ void main() {

    float Lfrcp[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-       Lfrcp[r] = 1.0 / Lf[r];
+       Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
    }

    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};

layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};

-#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif

#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    // iqs is currently always zero in the flash attention shaders
    if (binding_idx == BINDING_IDX_K) {
        return k_packed.k_data_packed[a_offset + ib];
    } else {
        return v_packed.v_data_packed[a_offset + ib];
    }
}
#endif

#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
@@ -380,7 +380,7 @@ void main() {

    float Lfrcp[rows_per_thread];
    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-       Lfrcp[r] = 1.0 / Lf[r];
+       Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
    }

    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
@@ -121,7 +121,11 @@ void main() {
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
@@ -294,7 +298,7 @@ void main() {

    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
-       Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
+       Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
    }

    O = Ldiag*O;
@@ -91,7 +91,7 @@ void main() {
        L = L*ms + vs;
    }

-   L = 1.0 / L;
+   L = (L == 0.0) ? 0.0 : 1.0 / L;

    // D dimension is split across workgroups in the y dimension
    uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
@@ -313,12 +313,12 @@ void main() {
        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
    }
#else
-   ACC_TYPE sums[WMITER * TM * WNITER * TN];
+   ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
    FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
-   FLOAT_TYPE_VEC2 cache_b[TN];
+   FLOAT_TYPE_VEC2 cache_b;

-   [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-       sums[i] = ACC_TYPE(0.0f);
+   [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+       sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
    }
#endif
@@ -360,20 +360,22 @@ void main() {
                cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
            }
        }

        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-           [[unroll]] for (uint j = 0; j < TN; j++) {
-               cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
-           }
+           [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+               cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];

                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                   [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                       [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                           const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                           sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
+                   [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                       // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
+                       const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
+                       sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
+                       sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
                    }
                }
            }
        }
    }
#endif
@@ -388,8 +390,9 @@ void main() {
        }
    }
#else
-   [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-       sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+   [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+       sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
+       sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
    }
#endif
#endif
@@ -463,14 +466,21 @@ void main() {

            const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
-           [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+           [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
                const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#ifdef MUL_MAT_ID
-               if (dr_warp + cr < p.M) {
-                   data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+               if (dr_warp + 2 * cr < p.M) {
+                   data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+               }
+               if (dr_warp + 2 * cr + 1 < p.M) {
+                   data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
                }
#else
-               if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                   data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+               if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
+                   data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+               }
+               if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
+                   data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
                }
#endif // MUL_MAT_ID
            }
@@ -0,0 +1,44 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#include "types.glsl"

layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };

layout(push_constant) uniform PushConstants {
    uint nb01; uint nb02;
    uint nb11;
    uint dst_nb0; uint dst_nb1; uint dst_nb2;
    uint nc; uint ncs; uint nr; uint n_t; uint n_s;
};

void main() {
    const uint global_thread_id = gl_GlobalInvocationID.x;
    const uint i2 = gl_WorkGroupID.y;
    const uint i3 = gl_WorkGroupID.z;

    if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
        return;
    }

    const uint i1 = global_thread_id;
    const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
    const uint src1_base = i1 * (nb11 / 4);
    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;

    float sum = 0.0;
    [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
        const uint src0_idx = src0_base + i0;
        const uint src1_idx = src1_base + i0;
        sum += src0[src0_idx] * src1[src1_idx];
    }

    dst[dst_idx] = sum;
}
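Each invocation of the convolution shader above computes one output element as a dot product of nc filter taps. A CPU sketch of that inner loop with illustrative, contiguous data (the real shader walks the nb* strides carried in the push constants):

#include <cstdio>
#include <vector>

int main() {
    const int nc = 4;                        // filter width
    std::vector<float> window(nc, 1.0f);     // one (row, token) slice of src0
    std::vector<float> taps(nc, 0.5f);       // matching src1 row
    float sum = 0.0f;
    for (int i0 = 0; i0 < nc; i0++) {
        sum += window[i0] * taps[i0];        // same accumulation as the shader
    }
    printf("dst value = %f\n", sum);         // 2.0 for this toy input
    return 0;
}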
@@ -0,0 +1,140 @@
#version 450

#extension GL_EXT_control_flow_attributes : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#include "types.glsl"

layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float s0[]; };
layout(binding = 1) readonly buffer Src1 { float x[]; };
layout(binding = 2) readonly buffer Src2 { float dt[]; };
layout(binding = 3) readonly buffer Src3 { float A[]; };
layout(binding = 4) readonly buffer Src4 { float B[]; };
layout(binding = 5) readonly buffer Src5 { float C[]; };
layout(binding = 6) readonly buffer Src6 { int ids[]; };
layout(binding = 7) buffer Dst { float d[]; };

layout(push_constant) uniform PushConstants {
    uint nb02; uint nb03; uint nb12; uint nb13;
    uint nb21; uint nb22; uint nb31;
    uint nb42; uint nb43; uint nb52; uint nb53;
    uint s_off;
    uint n_head;
    uint d_head;
    uint n_group;
    uint n_tok;
};

float softplus(float x) {
    if (x <= 20.0) {
        return log(1.0 + exp(x));
    } else {
        return x;
    }
}

shared float stateC[SPLIT_H * D_STATE];

void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
    const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
    const uint seq_idx = gl_WorkGroupID.y;

    const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
    const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
    const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
    const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
    const uint A_base_idx = (head_idx * nb31) / 4;
    const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
    const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
    const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;

    const uint stride_x = nb12 / 4;
    const uint stride_dt = nb21 / 4;
    const uint stride_B = nb42 / 4;
    const uint stride_C = nb52 / 4;
    const uint stride_y = n_head * d_head;

    float state[SPLIT_H];
    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        state[j] = s0[s0_base_idx + j * D_STATE + tid];
    }

    for (uint i = 0; i < n_tok; i++) {
        const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);

        const float dA = exp(dt_soft_plus * A[A_base_idx]);

        const float B_val = B[B_base_idx + i * stride_B + tid];
        const float C_val = C[C_base_idx + i * stride_C + tid];

        [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
            const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B_val * x_dt);

            stateC[j * D_STATE + tid] = state[j] * C_val;
        }

        barrier();
        [[unroll]]
        for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
            [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
                const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
                if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
                    stateC[k] += stateC[k + w];
                }
            }
            barrier();
        }

        [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
            const uint idx = (tid % SUBGROUP_SIZE) +
                D_STATE * (tid / SUBGROUP_SIZE) +
                j * D_STATE * (D_STATE / SUBGROUP_SIZE);
            const uint max_idx = SUBGROUP_SIZE - 1 +
                D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
                j * D_STATE * (D_STATE / SUBGROUP_SIZE);

            if (idx < SPLIT_H * D_STATE ||
                max_idx < SPLIT_H * D_STATE) {
                float sc;
#if USE_SUBGROUP_ADD
                sc = stateC[idx];
                sc = subgroupAdd(sc);
#else
                [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
                    if (idx + offset < SPLIT_H * D_STATE) {
                        stateC[idx] += stateC[idx + offset];
                    }
                    barrier();
                }
                if (tid % SUBGROUP_SIZE == 0) {
                    sc = stateC[idx];
                }
#endif

                if (tid % SUBGROUP_SIZE == 0) {
                    const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
                    d[y_base_idx + i * stride_y + k] = sc;
                }
            }
        }

        barrier();
    }

    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        d[s_base_idx + j * D_STATE + tid] = state[j];
    }
}
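Per token, the scan shader applies the state update state = state * exp(softplus(dt) * A) + B * (x * softplus(dt)) and then accumulates state * C over the state dimension into the output. A scalar sketch of that update for a single state element, with made-up values; the real kernel does this for SPLIT_H elements in parallel and reduces the C contributions in shared memory:

#include <cmath>
#include <cstdio>

// Same thresholded softplus as the shader.
static float softplus(float x) { return x <= 20.0f ? std::log(1.0f + std::exp(x)) : x; }

int main() {
    float state = 0.1f;                       // previous state value
    const float dt_raw = -0.3f, A = -1.0f;    // per-head decay parameters
    const float B = 0.7f, C = 0.9f, x = 0.25f;

    const float dt = softplus(dt_raw);
    const float dA = std::exp(dt * A);        // decay factor applied to the old state
    state = state * dA + B * (x * dt);        // state update for this token
    const float y_contrib = state * C;        // this element's share of the output sum
    printf("state=%f y_contrib=%f\n", state, y_contrib);
    return 0;
}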
@@ -0,0 +1,139 @@
#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"

layout (push_constant) uniform parameter
{
    uint n_rows;
    uint n_expert_used;
};

layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;

const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};

void main() {
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint weights_offset = n_expert_used * row;
    const uint ids_offset = n_experts * row;

    float logits_r[experts_per_thread];

    const float INFINITY = 1.0 / 0.0;

    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + gl_LocalInvocationID.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
    }

    float max_val = logits_r[0];

    [[unroll]]
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val = max(val, max_val);
    }

    max_val = subgroupMax(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = exp(val - max_val);
        tmp += wt[i];
    }

    tmp = subgroupAdd(tmp);

    const float inv_sum = 1.0f / tmp;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    // at this point, each thread holds a portion of softmax,
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    float output_weights[experts_per_thread];

    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        uint max_expert = gl_LocalInvocationID.x;

        [[unroll]]
        for (int i = 1; i < experts_per_thread; i++) {
            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(max_val, mask);
            const uint expert = subgroupShuffleXor(max_expert, mask);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }

        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            if (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if (with_norm) {
        wt_sum = subgroupAdd(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_weights[i];
        }
    }
}
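The top-k selection above relies on a butterfly argmax across the subgroup: each lane proposes its best (weight, expert) pair and subgroupShuffleXor exchanges proposals until every lane holds the winner, with ties going to the lower expert index. A CPU sketch of that reduction pattern, assuming an 8-lane subgroup purely for illustration:

#include <cstdio>
#include <vector>

int main() {
    const unsigned WARP_SIZE = 8;                        // illustrative subgroup size
    std::vector<float> val = {3, 7, 1, 9, 4, 9, 2, 5};   // one candidate value per lane
    std::vector<unsigned> idx = {0, 1, 2, 3, 4, 5, 6, 7};

    for (unsigned mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
        // Snapshot models the simultaneous shuffle exchange.
        std::vector<float> v2 = val;
        std::vector<unsigned> i2 = idx;
        for (unsigned lane = 0; lane < WARP_SIZE; lane++) {
            const float ov = v2[lane ^ mask];            // partner lane's value
            const unsigned oi = i2[lane ^ mask];         // partner lane's index
            if (ov > val[lane] || (ov == val[lane] && oi < idx[lane])) {
                val[lane] = ov;
                idx[lane] = oi;
            }
        }
    }
    printf("winner: value=%g index=%u\n", val[0], idx[0]);  // 9, index 3
    return 0;
}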
@@ -611,9 +611,6 @@ void process_shaders() {
    }

    for (const auto& tname : type_names) {
-       if (tname == "f32") {
-           continue;
-       }
        if (tname == "bf16") continue;

#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -630,7 +627,7 @@ void process_shaders() {
        if (tname == "f16") {
            string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-       } else if (tname == "q4_0" || tname == "q8_0") {
+       } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
            std::string data_a_key = "DATA_A_" + to_uppercase(tname);
            string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
@@ -639,7 +636,7 @@ void process_shaders() {
        if (tname == "f16") {
            string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
-       } else if (tname == "q4_0" || tname == "q8_0") {
+       } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
            std::string data_a_key = "DATA_A_" + to_uppercase(tname);
            string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
@@ -919,6 +916,13 @@ void process_shaders() {
    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
    string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});

    string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
    string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});

    string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});

    string_to_spv("topk_moe_f32", "topk_moe.comp", {});

    for (auto &c : compiles) {
        c.wait();
    }
@@ -962,7 +966,7 @@ void write_output_files() {
    }

    std::string suffixes[2] = {"_f32", "_f16"};
-   for (auto op : {"add", "sub", "mul", "div", "add_rms"}) {
+   for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
        hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
        hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
Some files were not shown because too many files have changed in this diff.