From 7e3ea813c1d8a9714c6927f75656d5ff6eaf5acc Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Mon, 15 Dec 2025 18:00:08 -0800 Subject: [PATCH] llama/parsers/renderers: nemotron 3 nano (#13489) --------- Co-authored-by: Daniel Hiltgen --- llama/llama.cpp/src/llama-arch.cpp | 35 ++ llama/llama.cpp/src/llama-arch.h | 1 + llama/llama.cpp/src/llama-graph.cpp | 10 + llama/llama.cpp/src/llama-model.cpp | 50 +- llama/llama.cpp/src/llama-model.h | 1 + llama/llama.cpp/src/models/nemotron-h.cpp | 41 +- ...d-support-for-NVIDIA-Nemotron-Nano-3.patch | 586 ++++++++++++++++++ model/parsers/nemotron3nano.go | 255 ++++++++ model/parsers/nemotron3nano_test.go | 583 +++++++++++++++++ model/parsers/parsers.go | 4 + model/renderers/nemotron3nano.go | 222 +++++++ model/renderers/nemotron3nano_test.go | 585 +++++++++++++++++ model/renderers/renderer.go | 6 + 13 files changed, 2364 insertions(+), 15 deletions(-) create mode 100644 llama/patches/0032-llama-add-support-for-NVIDIA-Nemotron-Nano-3.patch create mode 100644 model/parsers/nemotron3nano.go create mode 100644 model/parsers/nemotron3nano_test.go create mode 100644 model/renderers/nemotron3nano.go create mode 100644 model/renderers/nemotron3nano_test.go diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index a5fe4f66c..ac8b5e033 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -75,6 +75,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, + { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -1765,6 +1766,39 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NEMOTRON_H_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // MoE FFN (for MoE layers) + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B,"blk.%d.exp_probs_b" }, + // MoE shared expert layer + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_EXAONE, { @@ -2838,6 +2872,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: return true; default: diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index ec9e3a6df..61d73786c 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -79,6 +79,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, 
LLM_ARCH_NEMOTRON_H, + LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 43620df78..763202d87 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -1089,6 +1089,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_relu(ctx0, cur); cb(cur, "ffn_moe_relu", il); } break; + case LLM_FFN_RELU_SQR: + if (gate_exps) { + // TODO: add support for gated squared relu + GGML_ABORT("fatal error: gated squared relu not implemented"); + } else { + cur = ggml_relu(ctx0, cur); + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_moe_relu_sqr", il); + } + break; default: GGML_ABORT("fatal error"); } diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 3c503b424..94dee78c3 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -120,6 +120,8 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; @@ -1788,6 +1790,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); @@ -1803,7 +1806,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -5175,6 +5185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { // mamba2 Mixer SSM params // NOTE: int64_t for tensor dimensions @@ -5185,6 +5196,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5234,12 +5248,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } else { - // mlp layers - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } else { + if (n_expert != 0) { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } } } } break; @@ -6870,7 +6898,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_NEMOTRON_H) { + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -6926,7 +6955,8 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -7107,7 +7137,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (arch == LLM_ARCH_FALCON_H1) { filter_attn = [&](int32_t) { return true; }; filter_recr = [&](int32_t) { return true; }; - } else if (arch == LLM_ARCH_NEMOTRON_H) { + } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { filter_attn = [&](int32_t il) 
{ return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; @@ -7478,6 +7508,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { llm = std::make_unique(*this, params); } break; @@ -7765,6 +7796,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index cbf4e1bfa..b378b23ec 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -114,6 +114,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air diff --git a/llama/llama.cpp/src/models/nemotron-h.cpp b/llama/llama.cpp/src/models/nemotron-h.cpp index 541434888..eb135e63f 100644 --- a/llama/llama.cpp/src/models/nemotron-h.cpp +++ b/llama/llama.cpp/src/models/nemotron-h.cpp @@ -107,12 +107,41 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * } ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * ffn_inp = cur; + ggml_tensor * moe_out = + build_moe_ffn(ffn_inp, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, // no gate + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_RELU_SQR, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(moe_out, "ffn_moe_out", il); + + ggml_tensor * ffn_shexp = build_ffn(ffn_inp, + model.layers[il].ffn_up_shexp, NULL, NULL, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } cur = build_cvec(cur, il); cb(cur, "l_out", il); diff --git a/llama/patches/0032-llama-add-support-for-NVIDIA-Nemotron-Nano-3.patch b/llama/patches/0032-llama-add-support-for-NVIDIA-Nemotron-Nano-3.patch new file mode 100644 index 000000000..00536c4b8 --- /dev/null +++ b/llama/patches/0032-llama-add-support-for-NVIDIA-Nemotron-Nano-3.patch @@ -0,0 +1,586 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Bevenius +Date: Mon, 15 Dec 2025 15:13:49 +0100 +Subject: [PATCH] llama : add support for NVIDIA Nemotron Nano 3 + +This commit adds support for the NVIDIA Nemotron Nano 3 model, enabling +the conversion and running of this model. 
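
Nemotron Nano 3 is a hybrid Mamba2/attention model whose MLP layers become MoE blocks with a shared expert. Roughly, as a sketch of the intended dataflow rather than literal code, each expert layer computes

    y          = moe_out + shared_out
    moe_out    = routed_scaling_factor * sum_{e in top-k} w_e * down_e( relu(up_e(x))^2 )
    shared_out = down_shexp( relu(up_shexp(x))^2 )

where the routing weights w_e come from a sigmoid gate with a per-expert correction bias (exp_probs_b) and are optionally renormalized over the selected experts (expert_weights_norm). The routed experts, the shared expert, and the dense MLP layers all use the squared-ReLU activation, which is why build_moe_ffn gains an ungated LLM_FFN_RELU_SQR case.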
+ +fix indentation in llama-graph.cpp + +fix indentation and move ffn_inp + +convert : fix modify_tensors in NemotronHModel to call super() + +fix pyright error + +fix flake8 errors +--- + convert_hf_to_gguf.py | 116 +++++++++++++++++++++++++++++++-- + gguf-py/gguf/constants.py | 29 +++++++++ + gguf-py/gguf/tensor_mapping.py | 9 ++- + src/llama-arch.cpp | 35 ++++++++++ + src/llama-arch.h | 1 + + src/llama-graph.cpp | 10 +++ + src/llama-model.cpp | 50 +++++++++++--- + src/llama-model.h | 1 + + src/models/nemotron-h.cpp | 41 ++++++++++-- + 9 files changed, 269 insertions(+), 23 deletions(-) + +diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py +index 867bc9053..57ec2faac 100755 +--- a/convert_hf_to_gguf.py ++++ b/convert_hf_to_gguf.py +@@ -8601,8 +8601,18 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): + class NemotronHModel(GraniteHybridModel): + """Hybrid mamba2/attention model from NVIDIA""" + model_arch = gguf.MODEL_ARCH.NEMOTRON_H ++ is_moe: bool = False + + def __init__(self, *args, **kwargs): ++ # We have to determine the correct model architecture (MoE vs non-MoE) before ++ # calling the parent __init__. This is because the parent constructor ++ # uses self.model_arch to build the tensor name map, and all MoE-specific ++ # mappings would be missed if it were called with the default non-MoE arch. ++ hparams = ModelBase.load_hparams(args[0], self.is_mistral_format) ++ if "num_experts_per_tok" in hparams: ++ self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE ++ self.is_moe = True ++ + super().__init__(*args, **kwargs) + + # Save the top-level head_dim for later +@@ -8614,9 +8624,11 @@ class NemotronHModel(GraniteHybridModel): + + # Update the ssm / attn / mlp layers + # M: Mamba2, *: Attention, -: MLP ++ # MoE: ++ # M: Mamba2, *: Attention, E: Expert + hybrid_override_pattern = self.hparams["hybrid_override_pattern"] + self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"] +- self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"] ++ self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")] + + def get_attn_layers(self): + hybrid_override_pattern = self.hparams["hybrid_override_pattern"] +@@ -8632,10 +8644,28 @@ class NemotronHModel(GraniteHybridModel): + # Set feed_forward_length + # NOTE: This will trigger an override warning. 
This is preferrable to + # duplicating all the parent logic +- n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) +- self.gguf_writer.add_feed_forward_length([ +- n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) +- ]) ++ if not self.is_moe: ++ n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) ++ self.gguf_writer.add_feed_forward_length([ ++ n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) ++ ]) ++ else: ++ moe_intermediate_size = self.hparams["moe_intermediate_size"] ++ self.gguf_writer.add_feed_forward_length([ ++ moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count) ++ ]) ++ self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) ++ self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) ++ self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"]) ++ self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"]) ++ self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"]) ++ self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"]) ++ self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) ++ self.gguf_writer.add_expert_group_count(self.hparams["n_group"]) ++ ++ # number of experts used per token (top-k) ++ if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: ++ self.gguf_writer.add_expert_used_count(n_experts_used) + + def set_vocab(self): + super().set_vocab() +@@ -8643,7 +8673,81 @@ class NemotronHModel(GraniteHybridModel): + # The tokenizer _does_ add a BOS token (via post_processor type + # TemplateProcessing) but does not set add_bos_token to true in the + # config, so we need to explicitly override it here. 
+- self.gguf_writer.add_add_bos_token(True) ++ if not self.is_moe: ++ self.gguf_writer.add_add_bos_token(True) ++ ++ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: ++ if self.is_moe and bid is not None: ++ if name.endswith("mixer.gate.e_score_correction_bias"): ++ new_name = name.replace("e_score_correction_bias", "e_score_correction_bias.bias") ++ mapped_name = self.map_tensor_name(new_name) ++ return [(mapped_name, data_torch)] ++ ++ if name.endswith("mixer.dt_bias"): ++ new_name = name.replace("dt_bias", "dt.bias") ++ mapped_name = self.map_tensor_name(new_name) ++ return [(mapped_name, data_torch)] ++ ++ if name.endswith("mixer.conv1d.weight"): ++ squeezed_data = data_torch.squeeze() ++ mapped_name = self.map_tensor_name(name) ++ return [(mapped_name, squeezed_data)] ++ ++ if name.endswith("mixer.A_log"): ++ transformed_data = -torch.exp(data_torch) ++ reshaped_data = transformed_data.squeeze().reshape(-1, 1) ++ mapped_name = self.map_tensor_name(name) ++ return [(mapped_name, reshaped_data)] ++ ++ if name.endswith("mixer.D"): ++ reshaped_data = data_torch.squeeze().reshape(-1, 1) ++ mapped_name = self.map_tensor_name(name) ++ return [(mapped_name, reshaped_data)] ++ ++ if name.endswith("mixer.norm.weight"): ++ reshaped_data = data_torch.reshape(8, 512) ++ mapped_name = self.map_tensor_name(name) ++ return [(mapped_name, reshaped_data)] ++ ++ if name.find("mixer.experts") != -1: ++ n_experts = self.hparams["n_routed_experts"] ++ assert bid is not None ++ ++ if self._experts is None: ++ self._experts = [{} for _ in range(self.block_count)] ++ ++ self._experts[bid][name] = data_torch ++ ++ if len(self._experts[bid]) >= n_experts * 2: ++ # merge the experts into a single tensor ++ tensors: list[tuple[str, Tensor]] = [] ++ for w_name in ["down_proj", "up_proj"]: ++ datas: list[Tensor] = [] ++ ++ for xid in range(n_experts): ++ ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight" ++ datas.append(self._experts[bid][ename]) ++ del self._experts[bid][ename] ++ ++ data_torch = torch.stack(datas, dim=0) ++ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" ++ new_name = self.map_tensor_name(merged_name) ++ tensors.append((new_name, data_torch)) ++ ++ return tensors ++ else: ++ return [] ++ ++ return super().modify_tensors(data_torch, name, bid) ++ ++ def prepare_tensors(self): ++ super().prepare_tensors() ++ ++ if self._experts is not None: ++ # flatten `list[dict[str, Tensor]]` into `list[str]` ++ experts = [k for d in self._experts for k in d.keys()] ++ if len(experts) > 0: ++ raise ValueError(f"Unprocessed experts: {experts}") + + + @ModelBase.register("BailingMoeForCausalLM") +diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py +index 2b8489c59..1852428b4 100644 +--- a/gguf-py/gguf/constants.py ++++ b/gguf-py/gguf/constants.py +@@ -413,6 +413,7 @@ class MODEL_ARCH(IntEnum): + JAIS = auto() + NEMOTRON = auto() + NEMOTRON_H = auto() ++ NEMOTRON_H_MOE = auto() + EXAONE = auto() + EXAONE4 = auto() + GRANITE = auto() +@@ -786,6 +787,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.NEMOTRON_H: "nemotron_h", ++ MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.EXAONE4: "exaone4", + MODEL_ARCH.GRANITE: "granite", +@@ -2529,6 +2531,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], ++ MODEL_ARCH.NEMOTRON_H_MOE: [ ++ 
MODEL_TENSOR.TOKEN_EMBD, ++ MODEL_TENSOR.OUTPUT_NORM, ++ MODEL_TENSOR.OUTPUT, ++ MODEL_TENSOR.ATTN_NORM, ++ MODEL_TENSOR.SSM_IN, ++ MODEL_TENSOR.SSM_CONV1D, ++ MODEL_TENSOR.SSM_DT, ++ MODEL_TENSOR.SSM_A, ++ MODEL_TENSOR.SSM_D, ++ MODEL_TENSOR.SSM_NORM, ++ MODEL_TENSOR.SSM_OUT, ++ MODEL_TENSOR.ATTN_Q, ++ MODEL_TENSOR.ATTN_K, ++ MODEL_TENSOR.ATTN_V, ++ MODEL_TENSOR.ATTN_OUT, ++ MODEL_TENSOR.FFN_DOWN, ++ MODEL_TENSOR.FFN_UP, ++ # experts ++ MODEL_TENSOR.FFN_GATE_INP, ++ MODEL_TENSOR.FFN_UP_EXP, ++ MODEL_TENSOR.FFN_DOWN_EXP, ++ # shared expert ++ MODEL_TENSOR.FFN_DOWN_SHEXP, ++ MODEL_TENSOR.FFN_UP_SHEXP, ++ MODEL_TENSOR.FFN_EXP_PROBS_B, ++ ], + MODEL_ARCH.EXAONE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, +diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py +index d9c87da19..7a3c7c5e0 100644 +--- a/gguf-py/gguf/tensor_mapping.py ++++ b/gguf-py/gguf/tensor_mapping.py +@@ -377,6 +377,7 @@ class TensorNameMap: + "model.layers.{bid}.feed_forward.gate", # lfm2moe + "model.layers.{bid}.mlp.router.gate", # afmoe + "layers.{bid}.gate", # mistral-large ++ "backbone.layers.{bid}.mixer.gate", # nemotron-h-moe + ), + + MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( +@@ -390,6 +391,7 @@ class TensorNameMap: + "model.layers.{bid}.mlp.expert_bias", # afmoe + "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe + "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2 ++ "backbone.layers.{bid}.mixer.gate.e_score_correction_bias" # nemotron-h-moe + ), + + # Feed-forward up +@@ -438,7 +440,7 @@ class TensorNameMap: + "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx +- "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe ++ "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe, nemotron-h-moe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 + "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe +@@ -452,6 +454,7 @@ class TensorNameMap: + "model.layers.{bid}.feed_forward.down_proj", + "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan + "layers.{bid}.shared_experts.w3", # mistral-large ++ "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe + ), + + MODEL_TENSOR.FFN_UP_CHEXP: ( +@@ -546,7 +549,7 @@ class TensorNameMap: + "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx +- "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe ++ "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe nemotron-h-moe (merged) + "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe + "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 +@@ -561,6 +564,7 @@ class TensorNameMap: + "model.layers.{bid}.shared_mlp.output_linear", # granitemoe + "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan + "layers.{bid}.shared_experts.w2", # mistral-large ++ "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe + ), + + MODEL_TENSOR.FFN_DOWN_CHEXP: ( +@@ -704,6 +708,7 @@ class TensorNameMap: + "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid + 
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2 + "model.layers.{bid}.linear_attn.dt_proj", # qwen3next ++ "backbone.layers.{bid}.mixer.dt", # nemotron-h-moe + ), + + MODEL_TENSOR.SSM_DT_NORM: ( +diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp +index a5fe4f66c..ac8b5e033 100644 +--- a/src/llama-arch.cpp ++++ b/src/llama-arch.cpp +@@ -75,6 +75,7 @@ static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, ++ { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_EXAONE4, "exaone4" }, + { LLM_ARCH_RWKV6, "rwkv6" }, +@@ -1765,6 +1766,39 @@ static const std::map> LLM_TENSOR_N + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, ++ { ++ LLM_ARCH_NEMOTRON_H_MOE, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_OUTPUT, "output" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ // mamba(2) ssm layers ++ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, ++ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, ++ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, ++ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, ++ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, ++ { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, ++ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, ++ // attention layers ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ // dense FFN ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ // MoE FFN (for MoE layers) ++ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, ++ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, ++ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, ++ { LLM_TENSOR_FFN_EXP_PROBS_B,"blk.%d.exp_probs_b" }, ++ // MoE shared expert layer ++ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, ++ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, ++ }, ++ }, + { + LLM_ARCH_EXAONE, + { +@@ -2838,6 +2872,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { + case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: + case LLM_ARCH_NEMOTRON_H: ++ case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_QWEN3NEXT: + return true; + default: +diff --git a/src/llama-arch.h b/src/llama-arch.h +index ec9e3a6df..61d73786c 100644 +--- a/src/llama-arch.h ++++ b/src/llama-arch.h +@@ -79,6 +79,7 @@ enum llm_arch { + LLM_ARCH_JAIS, + LLM_ARCH_NEMOTRON, + LLM_ARCH_NEMOTRON_H, ++ LLM_ARCH_NEMOTRON_H_MOE, + LLM_ARCH_EXAONE, + LLM_ARCH_EXAONE4, + LLM_ARCH_RWKV6, +diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp +index 43620df78..763202d87 100644 +--- a/src/llama-graph.cpp ++++ b/src/llama-graph.cpp +@@ -1089,6 +1089,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn( + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_moe_relu", il); + } break; ++ case LLM_FFN_RELU_SQR: ++ if (gate_exps) { ++ // TODO: add support for gated squared relu ++ GGML_ABORT("fatal error: gated squared relu not implemented"); ++ } else { ++ cur = ggml_relu(ctx0, cur); ++ cur = ggml_sqr(ctx0, cur); ++ cb(cur, "ffn_moe_relu_sqr", il); ++ } ++ break; + default: + GGML_ABORT("fatal error"); + } +diff --git a/src/llama-model.cpp b/src/llama-model.cpp +index 3c503b424..94dee78c3 100644 +--- a/src/llama-model.cpp ++++ b/src/llama-model.cpp +@@ -120,6 +120,8 @@ const char * llm_type_name(llm_type type) { + case LLM_TYPE_16B_A1B: return "16B.A1B"; + case LLM_TYPE_21B_A3B: return "21B.A3B"; + case LLM_TYPE_30B_A3B: return 
"30B.A3B"; ++ case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; ++ case LLM_TYPE_80B_A3B: return "80B.A3B"; + case LLM_TYPE_100B_A6B: return "100B.A6B"; + case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_230B_A10B: return "230B.A10B"; +@@ -1788,6 +1790,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { + } + } break; + case LLM_ARCH_NEMOTRON_H: ++ case LLM_ARCH_NEMOTRON_H_MOE: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); +@@ -1803,7 +1806,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ++ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ++ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ++ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); ++ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ++ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ++ + switch (hparams.n_layer) { ++ case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B + case 56: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } +@@ -5175,6 +5185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + } + } break; + case LLM_ARCH_NEMOTRON_H: ++ case LLM_ARCH_NEMOTRON_H_MOE: + { + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions +@@ -5185,6 +5196,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + ++ const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; ++ const int64_t n_ff_shexp = hparams.n_ff_shexp; ++ + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + +@@ -5234,12 +5248,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); +- } else { +- // mlp layers +- layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); +- layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); +- layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); +- layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); ++ } else { ++ if (n_expert != 0) { ++ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); ++ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); ++ ++ // MoE branch ++ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); ++ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); ++ ++ // Shared expert branch ++ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); ++ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); ++ ++ } else { ++ // mlp layers ++ layer.ffn_down = 
create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); ++ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); ++ layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); ++ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); ++ } + } + } + } break; +@@ -6870,7 +6898,8 @@ void llama_model::print_info() const { + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || +- arch == LLM_ARCH_NEMOTRON_H) { ++ arch == LLM_ARCH_NEMOTRON_H || ++ arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); +@@ -6926,7 +6955,8 @@ void llama_model::print_info() const { + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || +- arch == LLM_ARCH_GRANITE_HYBRID) { ++ arch == LLM_ARCH_GRANITE_HYBRID || ++ arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); +@@ -7107,7 +7137,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](int32_t) { return true; }; + filter_recr = [&](int32_t) { return true; }; +- } else if (arch == LLM_ARCH_NEMOTRON_H) { ++ } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { + filter_attn = [&](int32_t il) { + return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; +@@ -7478,6 +7508,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_NEMOTRON_H: ++ case LLM_ARCH_NEMOTRON_H_MOE: + { + llm = std::make_unique(*this, params); + } break; +@@ -7765,6 +7796,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { + case LLM_ARCH_ARWKV7: + case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_NEMOTRON_H: ++ case LLM_ARCH_NEMOTRON_H_MOE: + return LLAMA_ROPE_TYPE_NONE; + + // use what we call a normal RoPE, operating on pairs of consecutive head values +diff --git a/src/llama-model.h b/src/llama-model.h +index cbf4e1bfa..b378b23ec 100644 +--- a/src/llama-model.h ++++ b/src/llama-model.h +@@ -114,6 +114,7 @@ enum llm_type { + LLM_TYPE_16B_A1B, + LLM_TYPE_21B_A3B, // Ernie MoE small + LLM_TYPE_30B_A3B, ++ LLM_TYPE_31B_A3_5B, + LLM_TYPE_80B_A3B, // Qwen3 Next + LLM_TYPE_100B_A6B, + LLM_TYPE_106B_A12B, // GLM-4.5-Air +diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp +index 541434888..eb135e63f 100644 +--- a/src/models/nemotron-h.cpp ++++ b/src/models/nemotron-h.cpp +@@ -107,12 +107,41 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * + } + + ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { +- cur = build_ffn(cur, +- model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, +- NULL, NULL, NULL, +- model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, +- NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); +- cb(cur, "ffn_out", il); ++ if (model.layers[il].ffn_gate_inp == nullptr) { ++ cur = 
build_ffn(cur, ++ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, ++ NULL, NULL, NULL, ++ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, ++ NULL, ++ LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); ++ cb(cur, "ffn_out", il); ++ } else { ++ ggml_tensor * ffn_inp = cur; ++ ggml_tensor * moe_out = ++ build_moe_ffn(ffn_inp, ++ model.layers[il].ffn_gate_inp, ++ model.layers[il].ffn_up_exps, ++ nullptr, // no gate ++ model.layers[il].ffn_down_exps, ++ model.layers[il].ffn_exp_probs_b, ++ n_expert, n_expert_used, ++ LLM_FFN_RELU_SQR, hparams.expert_weights_norm, ++ true, hparams.expert_weights_scale, ++ LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, ++ il); ++ cb(moe_out, "ffn_moe_out", il); ++ ++ ggml_tensor * ffn_shexp = build_ffn(ffn_inp, ++ model.layers[il].ffn_up_shexp, NULL, NULL, ++ NULL /* no gate */ , NULL, NULL, ++ model.layers[il].ffn_down_shexp, NULL, NULL, ++ NULL, ++ LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); ++ cb(ffn_shexp, "ffn_shexp", il); ++ ++ cur = ggml_add(ctx0, moe_out, ffn_shexp); ++ cb(cur, "ffn_out", il); ++ } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); diff --git a/model/parsers/nemotron3nano.go b/model/parsers/nemotron3nano.go new file mode 100644 index 000000000..c230b4e7d --- /dev/null +++ b/model/parsers/nemotron3nano.go @@ -0,0 +1,255 @@ +package parsers + +import ( + "regexp" + "strings" + "unicode" + + "github.com/ollama/ollama/api" +) + +type Nemotron3NanoParserState int + +const ( + Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota + Nemotron3NanoSkipWhitespaceAfterThinking + Nemotron3NanoCollectingContent + Nemotron3NanoCollectingToolCalls +) + +const ( + nemotronThinkClose = "" + nemotronToolCallOpen = "" + nemotronToolCallClose = "" +) + +type Nemotron3NanoParser struct { + state Nemotron3NanoParserState + buffer strings.Builder + tools []api.Tool + HasThinking bool +} + +func (p *Nemotron3NanoParser) HasToolSupport() bool { return true } +func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return p.HasThinking } + +func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.tools = tools + + // Check both model capability AND request preference + thinkingEnabled := thinkValue != nil && thinkValue.Bool() + + prefill := lastMessage != nil && lastMessage.Role == "assistant" + + if !thinkingEnabled { + p.state = Nemotron3NanoCollectingContent + return tools + } + + if prefill && lastMessage.Content != "" { + p.state = Nemotron3NanoCollectingContent + return tools + } + + p.state = Nemotron3NanoCollectingThinking + return tools +} + +type nemotronEvent interface { + isNemotronEvent() +} + +type nemotronEventThinkingContent struct { + content string +} + +type nemotronEventContent struct { + content string +} + +type nemotronEventToolCall struct { + toolCall api.ToolCall +} + +func (nemotronEventThinkingContent) isNemotronEvent() {} +func (nemotronEventContent) isNemotronEvent() {} +func (nemotronEventToolCall) isNemotronEvent() {} + +func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + for _, event := range events { + switch event := event.(type) { + case nemotronEventToolCall: + toolCalls = append(toolCalls, event.toolCall) + case nemotronEventThinkingContent: + thinkingSb.WriteString(event.content) + case nemotronEventContent: + 
contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent { + var all []nemotronEvent + + keepLooping := true + for keepLooping { + var events []nemotronEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + return all +} + +// emitWithPartialCheck extracts unambiguous content before a potential partial tag +func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) { + if overlapLen := overlap(bufStr, tag); overlapLen > 0 { + beforePartialTag := bufStr[:len(bufStr)-overlapLen] + trailingLen := trailingWhitespaceLen(beforePartialTag) + return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:] + } + wsLen := trailingWhitespaceLen(bufStr) + return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:] +} + +func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) { + bufStr := p.buffer.String() + if bufStr == "" { + return nil, false + } + + switch p.state { + case Nemotron3NanoCollectingThinking: + if strings.Contains(bufStr, nemotronThinkClose) { + split := strings.SplitN(bufStr, nemotronThinkClose, 2) + thinking := strings.TrimRightFunc(split[0], unicode.IsSpace) + p.buffer.Reset() + remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.buffer.WriteString(remainder) + // Transition to whitespace-skipping state if buffer is empty, + // otherwise go directly to content collection + if remainder == "" { + p.state = Nemotron3NanoSkipWhitespaceAfterThinking + } else { + p.state = Nemotron3NanoCollectingContent + } + if thinking != "" { + return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true + } + return nil, true + } + unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose) + p.buffer.Reset() + p.buffer.WriteString(ambig) + if unambig != "" { + return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false + } + return nil, false + + // We only want to skip whitespace between thinking and content + case Nemotron3NanoSkipWhitespaceAfterThinking: + bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(bufStr) + if bufStr == "" { + return nil, false + } + p.state = Nemotron3NanoCollectingContent + return nil, true + + case Nemotron3NanoCollectingContent: + if strings.Contains(bufStr, nemotronToolCallOpen) { + split := strings.SplitN(bufStr, nemotronToolCallOpen, 2) + content := strings.TrimRightFunc(split[0], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(split[1]) + p.state = Nemotron3NanoCollectingToolCalls + if content != "" { + return []nemotronEvent{nemotronEventContent{content: content}}, true + } + return nil, true + } + unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen) + p.buffer.Reset() + p.buffer.WriteString(ambig) + if unambig != "" { + return []nemotronEvent{nemotronEventContent{content: unambig}}, false + } + return nil, false + + case Nemotron3NanoCollectingToolCalls: + if strings.Contains(bufStr, nemotronToolCallClose) { + split := strings.SplitN(bufStr, nemotronToolCallClose, 2) + remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(remaining) + + var events []nemotronEvent + if tc, err := p.parseToolCall(split[0]); err == nil { + events = append(events, nemotronEventToolCall{toolCall: tc}) + } + + if !strings.Contains(remaining, 
nemotronToolCallOpen) { + p.state = Nemotron3NanoCollectingContent + } + return events, true + } + return nil, false + } + + return nil, false +} + +var ( + nemotronFunctionRegex = regexp.MustCompile(`]+)>`) + nemotronParameterRegex = regexp.MustCompile(`]+)>\n?([\s\S]*?)\n?`) +) + +func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) { + toolCall := api.ToolCall{} + + // Extract function name + fnMatch := nemotronFunctionRegex.FindStringSubmatch(content) + if len(fnMatch) < 2 { + return toolCall, nil + } + toolCall.Function.Name = fnMatch[1] + + // Extract parameters + toolCall.Function.Arguments = make(api.ToolCallFunctionArguments) + paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1) + for _, match := range paramMatches { + if len(match) >= 3 { + paramName := match[1] + paramValue := strings.TrimSpace(match[2]) + + // Try to parse as typed value based on tool definition + toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue) + } + } + + return toolCall, nil +} + +func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any { + // Find the matching tool to get parameter type + var paramType api.PropertyType + for _, tool := range p.tools { + if prop, ok := tool.Function.Parameters.Properties[paramName]; ok { + paramType = prop.Type + break + } + } + + return parseValue(raw, paramType) +} diff --git a/model/parsers/nemotron3nano_test.go b/model/parsers/nemotron3nano_test.go new file mode 100644 index 000000000..cfe74a314 --- /dev/null +++ b/model/parsers/nemotron3nano_test.go @@ -0,0 +1,583 @@ +package parsers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestNemotron3NanoParser(t *testing.T) { + tests := []struct { + name string + input string + thinkValue *api.ThinkValue + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + }{ + { + name: "simple content - no thinking", + input: "Hello, how can I help you?", + thinkValue: nil, + expectedContent: "Hello, how can I help you?", + }, + { + name: "simple content - thinking disabled", + input: "Hello, how can I help you?", + thinkValue: &api.ThinkValue{Value: false}, + expectedContent: "Hello, how can I help you?", + }, + { + name: "thinking then content", + input: "Let me think about this...\nHere is my answer.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Let me think about this...", + expectedContent: "Here is my answer.", + }, + { + name: "thinking with newlines", + input: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude\nThe answer is 42.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude", + expectedContent: "The answer is 42.", + }, + { + name: "simple tool call", + input: "\n\n\nParis\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + { + name: "content then tool call", + input: "Let me check the weather.\n\n\n\nNYC\n\n\n", + thinkValue: nil, + expectedContent: "Let me check the weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "NYC"}, + }, + }, + }, + }, + { + name: "tool call with multiple parameters", + input: "\n\n\nSFO\n\n\nNYC\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: 
"book_flight", + Arguments: map[string]any{ + "from": "SFO", + "to": "NYC", + }, + }, + }, + }, + }, + { + name: "multiple tool calls", + input: "\n\n\nSan Francisco\n\n\n\n" + + "\n\n\nNew York\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "San Francisco"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "New York"}, + }, + }, + }, + }, + { + name: "thinking then tool call", + input: "I should check the weather...\n\n\n\nParis\n\n\n", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "I should check the weather...", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + { + name: "thinking content then tool call", + input: "Let me think...\nI'll check for you.\n\n\n\ntest\n\n\n", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Let me think...", + expectedContent: "I'll check for you.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: map[string]any{"query": "test"}, + }, + }, + }, + }, + { + name: "tool call with multiline parameter value", + input: "\n\n\nLine 1\nLine 2\nLine 3\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "create_note", + Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"}, + }, + }, + }, + }, + { + name: "empty thinking block - immediate close", + input: "\nHere is my answer.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "", + expectedContent: "Here is my answer.", + }, + { + name: "thinking disabled but model outputs think close anyway", + input: "\nSome content after spurious tag.", + thinkValue: &api.ThinkValue{Value: false}, + expectedContent: "\nSome content after spurious tag.", + }, + { + name: "tool call with no function name - returns empty tool call", + input: "\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}}, + }, + { + name: "content with newlines preserved", + input: "Line 1\n\nLine 2\n\n\nLine 3", + thinkValue: nil, + expectedContent: "Line 1\n\nLine 2\n\n\nLine 3", + }, + { + name: "thinking with only whitespace after close tag", + input: "My thoughts... \n\t\n Content here.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "My thoughts...", + expectedContent: "Content here.", + }, + { + name: "unicode content", + input: "Hello 世界! 🌍 Ñoño", + thinkValue: nil, + expectedContent: "Hello 世界! 🌍 Ñoño", + }, + { + name: "tool call with numeric parameter", + input: "\n\n\n42\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_temp", + Arguments: map[string]any{"value": "42"}, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Nemotron3NanoParser{HasThinking: tt.thinkValue != nil && tt.thinkValue.Bool()} + p.Init(nil, nil, tt.thinkValue) + + content, thinking, calls, err := p.Add(tt.input, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Drain remaining content + finalContent, finalThinking, finalCalls, err := p.Add("", true) + if err != nil { + t.Fatalf("unexpected error on done: %v", err) + } + content += finalContent + thinking += finalThinking + calls = append(calls, finalCalls...) 
+ + if diff := cmp.Diff(content, tt.expectedContent); diff != "" { + t.Errorf("content mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" { + t.Errorf("thinking mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" { + t.Errorf("calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestNemotron3NanoParser_Streaming(t *testing.T) { + tests := []struct { + name string + chunks []string + thinkValue *api.ThinkValue + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + }{ + { + name: "streaming content character by character", + chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"}, + thinkValue: nil, + expectedContent: "Hello, world!", + }, + { + name: "streaming content small tokens", + chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"}, + thinkValue: nil, + expectedContent: "Hello, how can I help you today?", + }, + { + name: "streaming thinking then content - granular", + chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Let me think about this...", + expectedContent: "Here is my answer.", + }, + { + name: "streaming thinking with newlines - granular", + chunks: []string{"Step", " 1", ":", " Ana", "lyze\n", "Step", " 2", ":", " Pro", "cess", "", "\n", "The", " ans", "wer."}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Step 1: Analyze\nStep 2: Process", + expectedContent: "The answer.", + }, + { + name: "streaming tool call - highly granular", + chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "", "\n", "", "\n", ""}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + { + name: "streaming content then tool call - granular", + chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "", "\n", "", "\n", "NYC", "\n", "", "\n", "", "\n", ""}, + thinkValue: nil, + expectedContent: "Let me check the weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "NYC"}, + }, + }, + }, + }, + { + name: "tool call tag split character by character", + chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"}, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: map[string]any{}, + }, + }, + }, + }, + { + name: "thinking close tag split character by character", + chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "I'm thinking...", + expectedContent: "Done!", + }, + { + name: "multiple whitespace after think tag - separate chunks", + chunks: []string{"Thinking...", "", "\n", "\n", " ", "Content here."}, + thinkValue: 
&api.ThinkValue{Value: true}, + expectedThinking: "Thinking...", + expectedContent: "Content here.", + }, + { + name: "tool call with multiple parameters - streaming", + chunks: []string{"\n", "", "\n\n", "SFO\n", "", "\n\nNYC", "\n", "\n\n", ""}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "book_flight", + Arguments: map[string]any{ + "from": "SFO", + "to": "NYC", + }, + }, + }, + }, + }, + { + name: "thinking then content then tool call - streaming", + chunks: []string{"Ana", "lyzing", " your", " request", "...", "\n", "I'll", " check", " that", " for", " you", ".", "\n", "\n", "\n", "\n", "test", " query", "\n\n", "\n", ""}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Analyzing your request...", + expectedContent: "I'll check that for you.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: map[string]any{"query": "test query"}, + }, + }, + }, + }, + { + name: "multiple tool calls - streaming", + chunks: []string{ + "", "\n", "", "\n", + "\n", "San Fran", "cisco\n", "", "\n", + "", "\n", "", "\n", + "\n", "\n", + "\nNew", " York\n", "\n", + "\n", "", + }, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "San Francisco"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "New York"}, + }, + }, + }, + }, + { + name: "tool call with multiline parameter - streaming", + chunks: []string{"\n", "\n", "\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n\n", "\n", ""}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "create_note", + Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"}, + }, + }, + }, + }, + { + name: "empty thinking block", + chunks: []string{"", "\n", "Just content."}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "", + expectedContent: "Just content.", + }, + { + name: "empty input chunks interspersed", + chunks: []string{"Hello", "", " ", "", "world", "", "!"}, + thinkValue: nil, + expectedContent: "Hello world!", + }, + { + name: "tool call immediately after think close - no content", + chunks: []string{"Analyzing...", "", "\n", "", "\n\n\n", ""}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Analyzing...", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: map[string]any{}, + }, + }, + }, + }, + { + name: "tool call with empty parameter value", + chunks: []string{"\n\n\n", "\n\n\n"}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: map[string]any{"name": ""}, + }, + }, + }, + }, + { + name: "partial tool call tag at end - buffered", + chunks: []string{"Here's some content", " 0 && messages[0].Role == "system" { + systemMessage = messages[0].Content + loopMessages = messages[1:] + } else { + loopMessages = messages + } + + // Find last user message index for thinking truncation + lastUserIdx := -1 + for i, msg := range loopMessages { + if msg.Role == "user" { + lastUserIdx = i + } + } + + sb.WriteString("<|im_start|>system\n") + if systemMessage != "" { + sb.WriteString(systemMessage) + } + + if len(tools) > 0 { + if systemMessage != "" { + sb.WriteString("\n\n") + } + sb.WriteString(r.renderTools(tools)) + } + sb.WriteString("<|im_end|>\n") + + for i, message := range loopMessages { + 
switch message.Role { + case "assistant": + // Build content with thinking tags + content := r.buildContent(message) + shouldTruncate := i < lastUserIdx + + if len(message.ToolCalls) > 0 { + sb.WriteString("<|im_start|>assistant\n") + sb.WriteString(r.formatContent(content, shouldTruncate, true)) + r.writeToolCalls(&sb, message.ToolCalls) + sb.WriteString("<|im_end|>\n") + } else { + formatted := r.formatContent(content, shouldTruncate, false) + sb.WriteString("<|im_start|>assistant\n" + formatted + "<|im_end|>\n") + } + + case "user", "system": + sb.WriteString("<|im_start|>" + message.Role + "\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + + case "tool": + // Check if previous message was also a tool message + prevWasTool := i > 0 && loopMessages[i-1].Role == "tool" + nextIsTool := i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" + + if !prevWasTool { + sb.WriteString("<|im_start|>user\n") + } + sb.WriteString("\n") + sb.WriteString(message.Content) + sb.WriteString("\n\n") + + if !nextIsTool { + sb.WriteString("<|im_end|>\n") + } + + default: + sb.WriteString("<|im_start|>" + message.Role + "\n" + message.Content + "<|im_end|>\n") + } + } + + // Add generation prompt + if enableThinking { + sb.WriteString("<|im_start|>assistant\n\n") + } else { + sb.WriteString("<|im_start|>assistant\n") + } + + return sb.String(), nil +} + +func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string { + var sb strings.Builder + sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n") + + for _, tool := range tools { + fn := tool.Function + sb.WriteString("\n\n" + fn.Name + "") + + if fn.Description != "" { + sb.WriteString("\n" + strings.TrimSpace(fn.Description) + "") + } + + sb.WriteString("\n") + if fn.Parameters.Properties != nil { + for paramName, paramFields := range fn.Parameters.Properties { + sb.WriteString("\n") + sb.WriteString("\n" + paramName + "") + + if len(paramFields.Type) > 0 { + sb.WriteString("\n" + strings.Join(paramFields.Type, ", ") + "") + } + + if paramFields.Description != "" { + sb.WriteString("\n" + strings.TrimSpace(paramFields.Description) + "") + } + + if len(paramFields.Enum) > 0 { + enumJSON, _ := json.Marshal(paramFields.Enum) + sb.WriteString("\n" + string(enumJSON) + "") + } + + sb.WriteString("\n") + } + } + + if len(fn.Parameters.Required) > 0 { + reqJSON, _ := json.Marshal(fn.Parameters.Required) + sb.WriteString("\n" + string(reqJSON) + "") + } + + sb.WriteString("\n") + sb.WriteString("\n") + } + + sb.WriteString("\n") + + sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n") + + return sb.String() +} + +func (r *Nemotron3NanoRenderer) buildContent(message api.Message) string { + // The parser always extracts thinking into the Thinking field, + // so Content will never have tags embedded + if message.Thinking != "" { + return "\n" + message.Thinking + "\n\n" + message.Content + } 
+ return "" + message.Content +} + +func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, addNewline bool) string { + if content == "" { + return "" + } + + if !truncate { + if addNewline { + return strings.TrimSpace(content) + "\n" + } + return strings.TrimSpace(content) + } + + // Truncate thinking - keep only content after + c := content + if strings.Contains(c, "") { + parts := strings.Split(c, "") + c = parts[len(parts)-1] + } else if strings.Contains(c, "") { + parts := strings.Split(c, "") + c = parts[0] + } + c = "" + strings.TrimSpace(c) + + if addNewline && len(c) > len("") { + return c + "\n" + } + if c == "" { + return c + } + return strings.TrimSpace(c) +} + +func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) { + for _, tc := range toolCalls { + sb.WriteString("\n\n") + for name, value := range tc.Function.Arguments { + sb.WriteString("\n" + r.formatArgValue(value) + "\n\n") + } + sb.WriteString("\n\n") + } +} + +func (r *Nemotron3NanoRenderer) formatArgValue(value any) string { + switch v := value.(type) { + case map[string]any, []any: + jsonBytes, _ := json.Marshal(v) + return string(jsonBytes) + default: + return fmt.Sprintf("%v", v) + } +} diff --git a/model/renderers/nemotron3nano_test.go b/model/renderers/nemotron3nano_test.go new file mode 100644 index 000000000..dad528cc9 --- /dev/null +++ b/model/renderers/nemotron3nano_test.go @@ -0,0 +1,585 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestNemotron3NanoRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + tools []api.Tool + thinkValue *api.ThinkValue + isThinking bool + expected string + }{ + { + name: "basic user message - thinking mode", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "basic user message - no thinking", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + isThinking: false, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n", + }, + { + name: "with system message", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello!"}, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "multi-turn conversation", + msgs: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello! How can I help?"}, + {Role: "user", Content: "Tell me a joke"}, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHi<|im_end|>\n" + + "<|im_start|>assistant\nHello! 
How can I help?<|im_end|>\n" + + "<|im_start|>user\nTell me a joke<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "with tools", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"city"}, + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, + }, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "Get the current weather\n" + + "\n" + + "\ncity\nstring\nThe city name\n\n" + + "[\"city\"]\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "tool call with response", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny, 72F"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"city"}, + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, + }, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "Get the current weather\n" + + "\n" + + "\ncity\nstring\nThe city name\n\n" + + "[\"city\"]\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\nParis\n\n\n\n<|im_end|>\n" + + 
"<|im_start|>user\n\nSunny, 72F\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "assistant with content and tool call", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "Let me check that for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "\n" + + "\ncity\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather?<|im_end|>\n" + + "<|im_start|>assistant\nLet me check that for you.\n" + + "\n\n\nParis\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nSunny\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "thinking in history is truncated", + msgs: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello!", Thinking: "Let me think about this..."}, + {Role: "user", Content: "How are you?"}, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHi<|im_end|>\n" + + "<|im_start|>assistant\nHello!<|im_end|>\n" + + "<|im_start|>user\nHow are you?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "parallel tool calls", + msgs: []api.Message{ + {Role: "user", Content: "Weather in Paris and London?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "London"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + {Role: "tool", Content: "Rainy"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "\n" + + "\ncity\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + 
"\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWeather in Paris and London?<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\nParis\n\n\n\n" + + "\n\n\nLondon\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nSunny\n\n\nRainy\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "thinking disabled even when model supports it", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + isThinking: true, // model supports thinking + thinkValue: nil, // but user didn't request it + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n", + }, + { + name: "complex message history with thinking, tools, tool calls, tool results and content", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"}, + {Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}}, + {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}}, + }}, + {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"}, + {Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"}, + {Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}, + }}, + {Role: "tool", Content: "4", ToolCallID: "call3"}, + {Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! 
I have all the information needed to provide a complete answer."}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "calculate", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "expression": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "\n" + + "\ncity\nstring\n\n" + + "\n\n" + + "\ncalculate\n" + + "\n" + + "\nexpression\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather in Paris and London? Also, what's 2+2?<|im_end|>\n" + + "<|im_start|>assistant\n" + + "\nI need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.\n\n" + + "\n\n\nParis\n\n\n\n" + + "\n\n\nLondon\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nSunny, 22°C\n\n\nRainy, 15°C\n\n<|im_end|>\n" + + "<|im_start|>assistant\n" + + "\nNow I have the weather data. Let me calculate 2+2.\n\n" + + "\n\n\n2+2\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\n4\n\n<|im_end|>\n" + + "<|im_start|>assistant\n" + + "\nPerfect! I have all the information needed to provide a complete answer.\n\n" + + "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. 
Also, 2+2 equals 4.<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "empty messages list", + msgs: []api.Message{}, + isThinking: false, + expected: "<|im_start|>system\n<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "tool result with JSON content", + msgs: []api.Message{ + {Role: "user", Content: "Get user info"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}}, + }, + }, + {Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_user", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_user\n\n" + + "\nid\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nGet user info<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\n123\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\n{\"name\": \"John\", \"age\": 30, \"active\": true}\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "assistant message with only thinking no content", + msgs: []api.Message{ + {Role: "user", Content: "Think about this"}, + {Role: "assistant", Thinking: "Deep thoughts here...", Content: ""}, + {Role: "user", Content: "What did you think?"}, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nThink about this<|im_end|>\n" + + "<|im_start|>assistant\n<|im_end|>\n" + + "<|im_start|>user\nWhat did you think?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "tool call with complex nested argument", + msgs: []api.Message{ + {Role: "user", Content: "Create data"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{ + Name: "create", + Arguments: map[string]any{ + "data": map[string]any{"nested": "value", "count": 42}, + }, + }}, + }, + }, + {Role: "tool", Content: "Created"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "create", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\ncreate\n\n" + + "\ndata\nobject\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is 
the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nCreate data<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\n{\"count\":42,\"nested\":\"value\"}\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nCreated\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "content explaining the format itself", + msgs: []api.Message{ + {Role: "user", Content: "How do I format a tool call?"}, + {Role: "assistant", Content: "To call a tool, use tags with inside."}, + {Role: "user", Content: "Thanks!"}, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHow do I format a tool call?<|im_end|>\n" + + "<|im_start|>assistant\nTo call a tool, use tags with inside.<|im_end|>\n" + + "<|im_start|>user\nThanks!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "unicode in content and tool args", + msgs: []api.Message{ + {Role: "user", Content: "Translate 你好"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}}, + }, + }, + {Role: "tool", Content: "Hello"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "translate", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "text": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + isThinking: true, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\ntranslate\n\n" + + "\ntext\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nTranslate 你好<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\n你好\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nHello\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + renderer := &Nemotron3NanoRenderer{IsThinking: tt.isThinking} + rendered, err := renderer.Render(tt.msgs, tt.tools, tt.thinkValue) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index d05182d94..7c03cf93f 100644 --- 
a/model/renderers/renderer.go
+++ b/model/renderers/renderer.go
@@ -76,6 +76,12 @@ func rendererForName(name string) Renderer {
 		// Used for Olmo-3-32B-Think
 		renderer := &Olmo3ThinkRenderer{Variant: Olmo3Think32B}
 		return renderer
+	case "nemotron-3-nano":
+		renderer := &Nemotron3NanoRenderer{IsThinking: false}
+		return renderer
+	case "nemotron-3-nano-thinking":
+		renderer := &Nemotron3NanoRenderer{IsThinking: true}
+		return renderer
 	default:
 		return nil
 	}
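
Note: the snippet below is an illustrative sketch, not part of the patch. It shows one way the renderer added above could be driven directly, mirroring the table-driven tests in model/renderers/nemotron3nano_test.go. It relies only on names visible in this change (Nemotron3NanoRenderer, its IsThinking field, the Render method, api.Message, api.ThinkValue); the standalone main package, the sample message, and the external import of the renderers package are assumptions for illustration.

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/renderers"
)

func main() {
	// Thinking-enabled variant, matching what rendererForName returns for
	// "nemotron-3-nano-thinking" in the hunk above.
	r := &renderers.Nemotron3NanoRenderer{IsThinking: true}

	// Render a single user turn; tools are optional and may be nil,
	// as in the test cases that pass no tools.
	prompt, err := r.Render([]api.Message{
		{Role: "user", Content: "Hello!"},
	}, nil, &api.ThinkValue{Value: true})
	if err != nil {
		panic(err)
	}

	// Prints the <|im_start|>/<|im_end|>-framed prompt expected by the
	// "basic user message - thinking mode" test case.
	fmt.Println(prompt)
}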