mirror of https://github.com/ollama/ollama
Compare commits
15 Commits
v0.13.4-rc
...
main
| Author | SHA1 | Date |
|---|---|---|
|
|
a013693f80 | |
|
|
f6a016f49d | |
|
|
45c4739374 | |
|
|
2dd029de12 | |
|
|
903b1fc97f | |
|
|
89eb795293 | |
|
|
7e3ea813c1 | |
|
|
7b95087b9d | |
|
|
971d62595a | |
|
|
ffbe8e076d | |
|
|
2c639431b1 | |
|
|
aacd1cb394 | |
|
|
e3731fb160 | |
|
|
8dbc9e7b68 | |
|
|
abe67acf8a |
|
|
@ -54,6 +54,13 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
|
||||||
|
|
||||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||||
|
|
||||||
|
# Define GGML version variables for shared library SOVERSION
|
||||||
|
# These are required by ggml/src/CMakeLists.txt for proper library versioning
|
||||||
|
set(GGML_VERSION_MAJOR 0)
|
||||||
|
set(GGML_VERSION_MINOR 0)
|
||||||
|
set(GGML_VERSION_PATCH 0)
|
||||||
|
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||||
|
|
||||||
set(GGML_CPU ON)
|
set(GGML_CPU ON)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||||
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
||||||
|
|
|
||||||
82
app/ui/ui.go
82
app/ui/ui.go
|
|
@ -12,13 +12,13 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
@ -117,40 +117,66 @@ func (s *Server) log() *slog.Logger {
|
||||||
|
|
||||||
// ollamaProxy creates a reverse proxy handler to the Ollama server
|
// ollamaProxy creates a reverse proxy handler to the Ollama server
|
||||||
func (s *Server) ollamaProxy() http.Handler {
|
func (s *Server) ollamaProxy() http.Handler {
|
||||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
var (
|
||||||
if ollamaHost == "" {
|
proxy http.Handler
|
||||||
ollamaHost = "http://127.0.0.1:11434"
|
proxyMu sync.Mutex
|
||||||
}
|
)
|
||||||
|
|
||||||
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ollamaHost = "http://" + ollamaHost
|
proxyMu.Lock()
|
||||||
}
|
p := proxy
|
||||||
|
proxyMu.Unlock()
|
||||||
|
|
||||||
target, err := url.Parse(ollamaHost)
|
if p == nil {
|
||||||
if err != nil {
|
proxyMu.Lock()
|
||||||
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
|
if proxy == nil {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var err error
|
||||||
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
|
for i := range 2 {
|
||||||
})
|
if i > 0 {
|
||||||
}
|
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
err = WaitForServer(context.Background(), 10*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
if err != nil {
|
||||||
|
proxyMu.Unlock()
|
||||||
|
s.log().Error("ollama server not ready after retries", "error", err)
|
||||||
|
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
originalDirector := proxy.Director
|
target := envconfig.Host()
|
||||||
proxy.Director = func(req *http.Request) {
|
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||||
originalDirector(req)
|
|
||||||
req.Host = target.Host
|
|
||||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
|
||||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
return proxy
|
originalDirector := newProxy.Director
|
||||||
|
newProxy.Director = func(req *http.Request) {
|
||||||
|
originalDirector(req)
|
||||||
|
req.Host = target.Host
|
||||||
|
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||||
|
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy = newProxy
|
||||||
|
p = newProxy
|
||||||
|
} else {
|
||||||
|
p = proxy
|
||||||
|
}
|
||||||
|
proxyMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
p.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
|
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
|
||||||
|
|
|
||||||
|
|
@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||||
conv = &qwen25VLModel{}
|
conv = &qwen25VLModel{}
|
||||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||||
conv = &qwen3VLModel{}
|
conv = &qwen3VLModel{}
|
||||||
|
case "Olmo3ForCausalLM":
|
||||||
|
conv = &olmoModel{}
|
||||||
case "BertModel":
|
case "BertModel":
|
||||||
conv = &bertModel{}
|
conv = &bertModel{}
|
||||||
case "NomicBertModel", "NomicBertMoEModel":
|
case "NomicBertModel", "NomicBertMoEModel":
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ropeScaling struct {
|
||||||
|
Factor float32 `json:"factor"`
|
||||||
|
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
|
||||||
|
AttentionFactor float32 `json:"attention_factor"`
|
||||||
|
BetaFast float32 `json:"beta_fast"`
|
||||||
|
BetaSlow float32 `json:"beta_slow"`
|
||||||
|
RopeType string `json:"rope_type"`
|
||||||
|
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type olmoModel struct {
|
||||||
|
ModelParameters
|
||||||
|
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||||
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
|
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
RopeScaling *ropeScaling `json:"rope_scaling"`
|
||||||
|
SlidingWindow uint32 `json:"sliding_window"`
|
||||||
|
LayerTypes []string `json:"layer_types"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ModelConverter = (*olmoModel)(nil)
|
||||||
|
|
||||||
|
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||||
|
kv := p.ModelParameters.KV(t)
|
||||||
|
kv["general.architecture"] = "olmo3"
|
||||||
|
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||||
|
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
|
||||||
|
kv["olmo3.embedding_length"] = p.HiddenSize
|
||||||
|
kv["olmo3.feed_forward_length"] = p.IntermediateSize
|
||||||
|
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
|
||||||
|
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||||
|
|
||||||
|
if p.RopeTheta > 0 {
|
||||||
|
kv["olmo3.rope.freq_base"] = p.RopeTheta
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.RopeScaling != nil {
|
||||||
|
if p.RopeScaling.Factor > 0 {
|
||||||
|
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||||
|
}
|
||||||
|
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
|
||||||
|
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
|
||||||
|
}
|
||||||
|
if p.RopeScaling.AttentionFactor > 0 {
|
||||||
|
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
|
||||||
|
}
|
||||||
|
if p.RopeScaling.RopeType != "" {
|
||||||
|
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.RMSNormEPS > 0 {
|
||||||
|
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.SlidingWindow > 0 {
|
||||||
|
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.LayerTypes) > 0 {
|
||||||
|
slidingPattern := make([]bool, len(p.LayerTypes))
|
||||||
|
for i, layerType := range p.LayerTypes {
|
||||||
|
slidingPattern[i] = (layerType == "sliding_attention")
|
||||||
|
}
|
||||||
|
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
|
out := make([]*ggml.Tensor, 0, len(ts))
|
||||||
|
for _, t := range ts {
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: t.Name(),
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: t.Shape(),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *olmoModel) Replacements() []string {
|
||||||
|
return []string{
|
||||||
|
"lm_head", "output",
|
||||||
|
"model.embed_tokens", "token_embd",
|
||||||
|
"model.layers", "blk",
|
||||||
|
"model.norm", "output_norm",
|
||||||
|
"self_attn.q_proj", "attn_q",
|
||||||
|
"self_attn.k_proj", "attn_k",
|
||||||
|
"self_attn.v_proj", "attn_v",
|
||||||
|
"self_attn.o_proj", "attn_output",
|
||||||
|
"self_attn.q_norm", "attn_q_norm",
|
||||||
|
"self_attn.k_norm", "attn_k_norm",
|
||||||
|
"post_attention_layernorm", "post_attention_norm",
|
||||||
|
"post_feedforward_layernorm", "post_ffw_norm",
|
||||||
|
"mlp.gate_proj", "ffn_gate",
|
||||||
|
"mlp.down_proj", "ffn_down",
|
||||||
|
"mlp.up_proj", "ffn_up",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -199,7 +199,7 @@ var (
|
||||||
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
||||||
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
||||||
// Enable the new Ollama engine
|
// Enable the new Ollama engine
|
||||||
NewEngine = BoolWithDefault("OLLAMA_NEW_ENGINE")
|
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||||
// ContextLength sets the default context length
|
// ContextLength sets the default context length
|
||||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||||
// Auth enables authentication between the Ollama client and server
|
// Auth enables authentication between the Ollama client and server
|
||||||
|
|
@ -291,7 +291,7 @@ func AsMap() map[string]EnvVar {
|
||||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(true), "Enable the new Ollama engine"},
|
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||||
|
|
||||||
// Informational
|
// Informational
|
||||||
|
|
|
||||||
|
|
@ -241,18 +241,20 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||||
|
|
||||||
func (kv KV) OllamaEngineRequired() bool {
|
func (kv KV) OllamaEngineRequired() bool {
|
||||||
return slices.Contains([]string{
|
return slices.Contains([]string{
|
||||||
|
"bert",
|
||||||
|
"deepseek2",
|
||||||
|
"deepseekocr",
|
||||||
"gemma3",
|
"gemma3",
|
||||||
"gemma3n",
|
"gemma3n",
|
||||||
"gptoss", "gpt-oss",
|
"gptoss", "gpt-oss",
|
||||||
"llama4",
|
"llama4",
|
||||||
"mistral3",
|
"mistral3",
|
||||||
"mllama",
|
"mllama",
|
||||||
|
"nomic-bert",
|
||||||
|
"olmo3",
|
||||||
"qwen25vl",
|
"qwen25vl",
|
||||||
"qwen3", "qwen3moe",
|
"qwen3", "qwen3moe",
|
||||||
"qwen3vl", "qwen3vlmoe",
|
"qwen3vl", "qwen3vlmoe",
|
||||||
"deepseekocr",
|
|
||||||
"deepseek2",
|
|
||||||
"nomic-bert",
|
|
||||||
}, kv.Architecture())
|
}, kv.Architecture())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -838,9 +840,11 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||||
// FlashAttention checks if the model should enable flash attention
|
// FlashAttention checks if the model should enable flash attention
|
||||||
func (f GGML) FlashAttention() bool {
|
func (f GGML) FlashAttention() bool {
|
||||||
return slices.Contains([]string{
|
return slices.Contains([]string{
|
||||||
|
"bert",
|
||||||
"gemma3",
|
"gemma3",
|
||||||
"gptoss", "gpt-oss",
|
"gptoss", "gpt-oss",
|
||||||
"mistral3",
|
"mistral3",
|
||||||
|
"olmo3",
|
||||||
"qwen3", "qwen3moe",
|
"qwen3", "qwen3moe",
|
||||||
"qwen3vl", "qwen3vlmoe",
|
"qwen3vl", "qwen3vlmoe",
|
||||||
}, f.KV().String("general.architecture"))
|
}, f.KV().String("general.architecture"))
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_JAIS, "jais" },
|
{ LLM_ARCH_JAIS, "jais" },
|
||||||
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
||||||
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
|
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
|
||||||
|
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
|
||||||
{ LLM_ARCH_EXAONE, "exaone" },
|
{ LLM_ARCH_EXAONE, "exaone" },
|
||||||
{ LLM_ARCH_EXAONE4, "exaone4" },
|
{ LLM_ARCH_EXAONE4, "exaone4" },
|
||||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||||
|
|
@ -1765,6 +1766,39 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_NEMOTRON_H_MOE,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
// mamba(2) ssm layers
|
||||||
|
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||||
|
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||||
|
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||||
|
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||||
|
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||||
|
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||||
|
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||||
|
// attention layers
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
// dense FFN
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
// MoE FFN (for MoE layers)
|
||||||
|
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||||
|
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||||
|
{ LLM_TENSOR_FFN_EXP_PROBS_B,"blk.%d.exp_probs_b" },
|
||||||
|
// MoE shared expert layer
|
||||||
|
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||||
|
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_EXAONE,
|
LLM_ARCH_EXAONE,
|
||||||
{
|
{
|
||||||
|
|
@ -2838,6 +2872,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||||
case LLM_ARCH_LFM2:
|
case LLM_ARCH_LFM2:
|
||||||
case LLM_ARCH_LFM2MOE:
|
case LLM_ARCH_LFM2MOE:
|
||||||
case LLM_ARCH_NEMOTRON_H:
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
case LLM_ARCH_QWEN3NEXT:
|
case LLM_ARCH_QWEN3NEXT:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,7 @@ enum llm_arch {
|
||||||
LLM_ARCH_JAIS,
|
LLM_ARCH_JAIS,
|
||||||
LLM_ARCH_NEMOTRON,
|
LLM_ARCH_NEMOTRON,
|
||||||
LLM_ARCH_NEMOTRON_H,
|
LLM_ARCH_NEMOTRON_H,
|
||||||
|
LLM_ARCH_NEMOTRON_H_MOE,
|
||||||
LLM_ARCH_EXAONE,
|
LLM_ARCH_EXAONE,
|
||||||
LLM_ARCH_EXAONE4,
|
LLM_ARCH_EXAONE4,
|
||||||
LLM_ARCH_RWKV6,
|
LLM_ARCH_RWKV6,
|
||||||
|
|
|
||||||
|
|
@ -1089,6 +1089,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
cur = ggml_relu(ctx0, cur);
|
cur = ggml_relu(ctx0, cur);
|
||||||
cb(cur, "ffn_moe_relu", il);
|
cb(cur, "ffn_moe_relu", il);
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_FFN_RELU_SQR:
|
||||||
|
if (gate_exps) {
|
||||||
|
// TODO: add support for gated squared relu
|
||||||
|
GGML_ABORT("fatal error: gated squared relu not implemented");
|
||||||
|
} else {
|
||||||
|
cur = ggml_relu(ctx0, cur);
|
||||||
|
cur = ggml_sqr(ctx0, cur);
|
||||||
|
cb(cur, "ffn_moe_relu_sqr", il);
|
||||||
|
}
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -120,6 +120,8 @@ const char * llm_type_name(llm_type type) {
|
||||||
case LLM_TYPE_16B_A1B: return "16B.A1B";
|
case LLM_TYPE_16B_A1B: return "16B.A1B";
|
||||||
case LLM_TYPE_21B_A3B: return "21B.A3B";
|
case LLM_TYPE_21B_A3B: return "21B.A3B";
|
||||||
case LLM_TYPE_30B_A3B: return "30B.A3B";
|
case LLM_TYPE_30B_A3B: return "30B.A3B";
|
||||||
|
case LLM_TYPE_31B_A3_5B: return "31B.A3.5B";
|
||||||
|
case LLM_TYPE_80B_A3B: return "80B.A3B";
|
||||||
case LLM_TYPE_100B_A6B: return "100B.A6B";
|
case LLM_TYPE_100B_A6B: return "100B.A6B";
|
||||||
case LLM_TYPE_106B_A12B: return "106B.A12B";
|
case LLM_TYPE_106B_A12B: return "106B.A12B";
|
||||||
case LLM_TYPE_230B_A10B: return "230B.A10B";
|
case LLM_TYPE_230B_A10B: return "230B.A10B";
|
||||||
|
|
@ -1788,6 +1790,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_NEMOTRON_H:
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||||
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||||
|
|
@ -1803,7 +1806,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
|
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||||
|
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
||||||
|
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
|
||||||
|
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||||
|
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||||
|
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
|
case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B
|
||||||
case 56: type = LLM_TYPE_9B; break;
|
case 56: type = LLM_TYPE_9B; break;
|
||||||
default: type = LLM_TYPE_UNKNOWN;
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
}
|
}
|
||||||
|
|
@ -5175,6 +5185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_NEMOTRON_H:
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
{
|
{
|
||||||
// mamba2 Mixer SSM params
|
// mamba2 Mixer SSM params
|
||||||
// NOTE: int64_t for tensor dimensions
|
// NOTE: int64_t for tensor dimensions
|
||||||
|
|
@ -5185,6 +5196,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
const int64_t n_group = hparams.ssm_n_group;
|
const int64_t n_group = hparams.ssm_n_group;
|
||||||
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
|
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
|
||||||
|
|
||||||
|
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||||
|
const int64_t n_ff_shexp = hparams.n_ff_shexp;
|
||||||
|
|
||||||
// embeddings
|
// embeddings
|
||||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||||
|
|
||||||
|
|
@ -5234,12 +5248,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
|
||||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
|
||||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||||
} else {
|
} else {
|
||||||
// mlp layers
|
if (n_expert != 0) {
|
||||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
|
||||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
|
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0);
|
||||||
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
|
||||||
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
|
// MoE branch
|
||||||
|
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||||
|
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||||
|
|
||||||
|
// Shared expert branch
|
||||||
|
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
||||||
|
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// mlp layers
|
||||||
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
|
||||||
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
|
||||||
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||||
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -6870,7 +6898,8 @@ void llama_model::print_info() const {
|
||||||
arch == LLM_ARCH_PLAMO2 ||
|
arch == LLM_ARCH_PLAMO2 ||
|
||||||
arch == LLM_ARCH_GRANITE_HYBRID ||
|
arch == LLM_ARCH_GRANITE_HYBRID ||
|
||||||
arch == LLM_ARCH_QWEN3NEXT ||
|
arch == LLM_ARCH_QWEN3NEXT ||
|
||||||
arch == LLM_ARCH_NEMOTRON_H) {
|
arch == LLM_ARCH_NEMOTRON_H ||
|
||||||
|
arch == LLM_ARCH_NEMOTRON_H_MOE) {
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
||||||
|
|
@ -6926,7 +6955,8 @@ void llama_model::print_info() const {
|
||||||
if (arch == LLM_ARCH_MINICPM ||
|
if (arch == LLM_ARCH_MINICPM ||
|
||||||
arch == LLM_ARCH_GRANITE ||
|
arch == LLM_ARCH_GRANITE ||
|
||||||
arch == LLM_ARCH_GRANITE_MOE ||
|
arch == LLM_ARCH_GRANITE_MOE ||
|
||||||
arch == LLM_ARCH_GRANITE_HYBRID) {
|
arch == LLM_ARCH_GRANITE_HYBRID ||
|
||||||
|
arch == LLM_ARCH_NEMOTRON_H_MOE) {
|
||||||
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
||||||
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
||||||
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
||||||
|
|
@ -7107,7 +7137,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||||
if (arch == LLM_ARCH_FALCON_H1) {
|
if (arch == LLM_ARCH_FALCON_H1) {
|
||||||
filter_attn = [&](int32_t) { return true; };
|
filter_attn = [&](int32_t) { return true; };
|
||||||
filter_recr = [&](int32_t) { return true; };
|
filter_recr = [&](int32_t) { return true; };
|
||||||
} else if (arch == LLM_ARCH_NEMOTRON_H) {
|
} else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) {
|
||||||
filter_attn = [&](int32_t il) {
|
filter_attn = [&](int32_t il) {
|
||||||
return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
|
return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
|
||||||
};
|
};
|
||||||
|
|
@ -7478,6 +7508,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||||
llm = std::make_unique<llm_build_nemotron>(*this, params);
|
llm = std::make_unique<llm_build_nemotron>(*this, params);
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_NEMOTRON_H:
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
{
|
{
|
||||||
llm = std::make_unique<llm_build_nemotron_h>(*this, params);
|
llm = std::make_unique<llm_build_nemotron_h>(*this, params);
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -7765,6 +7796,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||||
case LLM_ARCH_ARWKV7:
|
case LLM_ARCH_ARWKV7:
|
||||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||||
case LLM_ARCH_NEMOTRON_H:
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
return LLAMA_ROPE_TYPE_NONE;
|
return LLAMA_ROPE_TYPE_NONE;
|
||||||
|
|
||||||
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||||
|
|
|
||||||
|
|
@ -114,6 +114,7 @@ enum llm_type {
|
||||||
LLM_TYPE_16B_A1B,
|
LLM_TYPE_16B_A1B,
|
||||||
LLM_TYPE_21B_A3B, // Ernie MoE small
|
LLM_TYPE_21B_A3B, // Ernie MoE small
|
||||||
LLM_TYPE_30B_A3B,
|
LLM_TYPE_30B_A3B,
|
||||||
|
LLM_TYPE_31B_A3_5B,
|
||||||
LLM_TYPE_80B_A3B, // Qwen3 Next
|
LLM_TYPE_80B_A3B, // Qwen3 Next
|
||||||
LLM_TYPE_100B_A6B,
|
LLM_TYPE_100B_A6B,
|
||||||
LLM_TYPE_106B_A12B, // GLM-4.5-Air
|
LLM_TYPE_106B_A12B, // GLM-4.5-Air
|
||||||
|
|
|
||||||
|
|
@ -107,12 +107,41 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
|
ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
|
||||||
cur = build_ffn(cur,
|
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
cur = build_ffn(cur,
|
||||||
NULL, NULL, NULL,
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
NULL, NULL, NULL,
|
||||||
NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||||
cb(cur, "ffn_out", il);
|
NULL,
|
||||||
|
LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
} else {
|
||||||
|
ggml_tensor * ffn_inp = cur;
|
||||||
|
ggml_tensor * moe_out =
|
||||||
|
build_moe_ffn(ffn_inp,
|
||||||
|
model.layers[il].ffn_gate_inp,
|
||||||
|
model.layers[il].ffn_up_exps,
|
||||||
|
nullptr, // no gate
|
||||||
|
model.layers[il].ffn_down_exps,
|
||||||
|
model.layers[il].ffn_exp_probs_b,
|
||||||
|
n_expert, n_expert_used,
|
||||||
|
LLM_FFN_RELU_SQR, hparams.expert_weights_norm,
|
||||||
|
true, hparams.expert_weights_scale,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
|
||||||
|
il);
|
||||||
|
cb(moe_out, "ffn_moe_out", il);
|
||||||
|
|
||||||
|
ggml_tensor * ffn_shexp = build_ffn(ffn_inp,
|
||||||
|
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||||
|
NULL /* no gate */ , NULL, NULL,
|
||||||
|
model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||||
|
NULL,
|
||||||
|
LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
|
||||||
|
cb(ffn_shexp, "ffn_shexp", il);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
cur = build_cvec(cur, il);
|
cur = build_cvec(cur, il);
|
||||||
cb(cur, "l_out", il);
|
cb(cur, "l_out", il);
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,586 @@
|
||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Daniel Bevenius <daniel.bevenius@gmail.com>
|
||||||
|
Date: Mon, 15 Dec 2025 15:13:49 +0100
|
||||||
|
Subject: [PATCH] llama : add support for NVIDIA Nemotron Nano 3
|
||||||
|
|
||||||
|
This commit adds support for the NVIDIA Nemotron Nano 3 model, enabling
|
||||||
|
the conversion and running of this model.
|
||||||
|
|
||||||
|
fix indentation in llama-graph.cpp
|
||||||
|
|
||||||
|
fix indentation and move ffn_inp
|
||||||
|
|
||||||
|
convert : fix modify_tensors in NemotronHModel to call super()
|
||||||
|
|
||||||
|
fix pyright error
|
||||||
|
|
||||||
|
fix flake8 errors
|
||||||
|
---
|
||||||
|
convert_hf_to_gguf.py | 116 +++++++++++++++++++++++++++++++--
|
||||||
|
gguf-py/gguf/constants.py | 29 +++++++++
|
||||||
|
gguf-py/gguf/tensor_mapping.py | 9 ++-
|
||||||
|
src/llama-arch.cpp | 35 ++++++++++
|
||||||
|
src/llama-arch.h | 1 +
|
||||||
|
src/llama-graph.cpp | 10 +++
|
||||||
|
src/llama-model.cpp | 50 +++++++++++---
|
||||||
|
src/llama-model.h | 1 +
|
||||||
|
src/models/nemotron-h.cpp | 41 ++++++++++--
|
||||||
|
9 files changed, 269 insertions(+), 23 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
|
||||||
|
index 867bc9053..57ec2faac 100755
|
||||||
|
--- a/convert_hf_to_gguf.py
|
||||||
|
+++ b/convert_hf_to_gguf.py
|
||||||
|
@@ -8601,8 +8601,18 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||||
|
class NemotronHModel(GraniteHybridModel):
|
||||||
|
"""Hybrid mamba2/attention model from NVIDIA"""
|
||||||
|
model_arch = gguf.MODEL_ARCH.NEMOTRON_H
|
||||||
|
+ is_moe: bool = False
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
+ # We have to determine the correct model architecture (MoE vs non-MoE) before
|
||||||
|
+ # calling the parent __init__. This is because the parent constructor
|
||||||
|
+ # uses self.model_arch to build the tensor name map, and all MoE-specific
|
||||||
|
+ # mappings would be missed if it were called with the default non-MoE arch.
|
||||||
|
+ hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
|
||||||
|
+ if "num_experts_per_tok" in hparams:
|
||||||
|
+ self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
|
||||||
|
+ self.is_moe = True
|
||||||
|
+
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Save the top-level head_dim for later
|
||||||
|
@@ -8614,9 +8624,11 @@ class NemotronHModel(GraniteHybridModel):
|
||||||
|
|
||||||
|
# Update the ssm / attn / mlp layers
|
||||||
|
# M: Mamba2, *: Attention, -: MLP
|
||||||
|
+ # MoE:
|
||||||
|
+ # M: Mamba2, *: Attention, E: Expert
|
||||||
|
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||||
|
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
|
||||||
|
- self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
|
||||||
|
+ self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]
|
||||||
|
|
||||||
|
def get_attn_layers(self):
|
||||||
|
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||||
|
@@ -8632,10 +8644,28 @@ class NemotronHModel(GraniteHybridModel):
|
||||||
|
# Set feed_forward_length
|
||||||
|
# NOTE: This will trigger an override warning. This is preferrable to
|
||||||
|
# duplicating all the parent logic
|
||||||
|
- n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
|
||||||
|
- self.gguf_writer.add_feed_forward_length([
|
||||||
|
- n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
|
||||||
|
- ])
|
||||||
|
+ if not self.is_moe:
|
||||||
|
+ n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
|
||||||
|
+ self.gguf_writer.add_feed_forward_length([
|
||||||
|
+ n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
|
||||||
|
+ ])
|
||||||
|
+ else:
|
||||||
|
+ moe_intermediate_size = self.hparams["moe_intermediate_size"]
|
||||||
|
+ self.gguf_writer.add_feed_forward_length([
|
||||||
|
+ moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count)
|
||||||
|
+ ])
|
||||||
|
+ self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
|
||||||
|
+ self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||||
|
+ self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"])
|
||||||
|
+ self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
|
||||||
|
+ self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
|
||||||
|
+ self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
|
||||||
|
+ self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
|
||||||
|
+ self.gguf_writer.add_expert_group_count(self.hparams["n_group"])
|
||||||
|
+
|
||||||
|
+ # number of experts used per token (top-k)
|
||||||
|
+ if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||||
|
+ self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||||
|
|
||||||
|
def set_vocab(self):
|
||||||
|
super().set_vocab()
|
||||||
|
@@ -8643,7 +8673,81 @@ class NemotronHModel(GraniteHybridModel):
|
||||||
|
# The tokenizer _does_ add a BOS token (via post_processor type
|
||||||
|
# TemplateProcessing) but does not set add_bos_token to true in the
|
||||||
|
# config, so we need to explicitly override it here.
|
||||||
|
- self.gguf_writer.add_add_bos_token(True)
|
||||||
|
+ if not self.is_moe:
|
||||||
|
+ self.gguf_writer.add_add_bos_token(True)
|
||||||
|
+
|
||||||
|
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
+ if self.is_moe and bid is not None:
|
||||||
|
+ if name.endswith("mixer.gate.e_score_correction_bias"):
|
||||||
|
+ new_name = name.replace("e_score_correction_bias", "e_score_correction_bias.bias")
|
||||||
|
+ mapped_name = self.map_tensor_name(new_name)
|
||||||
|
+ return [(mapped_name, data_torch)]
|
||||||
|
+
|
||||||
|
+ if name.endswith("mixer.dt_bias"):
|
||||||
|
+ new_name = name.replace("dt_bias", "dt.bias")
|
||||||
|
+ mapped_name = self.map_tensor_name(new_name)
|
||||||
|
+ return [(mapped_name, data_torch)]
|
||||||
|
+
|
||||||
|
+ if name.endswith("mixer.conv1d.weight"):
|
||||||
|
+ squeezed_data = data_torch.squeeze()
|
||||||
|
+ mapped_name = self.map_tensor_name(name)
|
||||||
|
+ return [(mapped_name, squeezed_data)]
|
||||||
|
+
|
||||||
|
+ if name.endswith("mixer.A_log"):
|
||||||
|
+ transformed_data = -torch.exp(data_torch)
|
||||||
|
+ reshaped_data = transformed_data.squeeze().reshape(-1, 1)
|
||||||
|
+ mapped_name = self.map_tensor_name(name)
|
||||||
|
+ return [(mapped_name, reshaped_data)]
|
||||||
|
+
|
||||||
|
+ if name.endswith("mixer.D"):
|
||||||
|
+ reshaped_data = data_torch.squeeze().reshape(-1, 1)
|
||||||
|
+ mapped_name = self.map_tensor_name(name)
|
||||||
|
+ return [(mapped_name, reshaped_data)]
|
||||||
|
+
|
||||||
|
+ if name.endswith("mixer.norm.weight"):
|
||||||
|
+ reshaped_data = data_torch.reshape(8, 512)
|
||||||
|
+ mapped_name = self.map_tensor_name(name)
|
||||||
|
+ return [(mapped_name, reshaped_data)]
|
||||||
|
+
|
||||||
|
+ if name.find("mixer.experts") != -1:
|
||||||
|
+ n_experts = self.hparams["n_routed_experts"]
|
||||||
|
+ assert bid is not None
|
||||||
|
+
|
||||||
|
+ if self._experts is None:
|
||||||
|
+ self._experts = [{} for _ in range(self.block_count)]
|
||||||
|
+
|
||||||
|
+ self._experts[bid][name] = data_torch
|
||||||
|
+
|
||||||
|
+ if len(self._experts[bid]) >= n_experts * 2:
|
||||||
|
+ # merge the experts into a single tensor
|
||||||
|
+ tensors: list[tuple[str, Tensor]] = []
|
||||||
|
+ for w_name in ["down_proj", "up_proj"]:
|
||||||
|
+ datas: list[Tensor] = []
|
||||||
|
+
|
||||||
|
+ for xid in range(n_experts):
|
||||||
|
+ ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight"
|
||||||
|
+ datas.append(self._experts[bid][ename])
|
||||||
|
+ del self._experts[bid][ename]
|
||||||
|
+
|
||||||
|
+ data_torch = torch.stack(datas, dim=0)
|
||||||
|
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||||
|
+ new_name = self.map_tensor_name(merged_name)
|
||||||
|
+ tensors.append((new_name, data_torch))
|
||||||
|
+
|
||||||
|
+ return tensors
|
||||||
|
+ else:
|
||||||
|
+ return []
|
||||||
|
+
|
||||||
|
+ return super().modify_tensors(data_torch, name, bid)
|
||||||
|
+
|
||||||
|
+ def prepare_tensors(self):
|
||||||
|
+ super().prepare_tensors()
|
||||||
|
+
|
||||||
|
+ if self._experts is not None:
|
||||||
|
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||||
|
+ experts = [k for d in self._experts for k in d.keys()]
|
||||||
|
+ if len(experts) > 0:
|
||||||
|
+ raise ValueError(f"Unprocessed experts: {experts}")
|
||||||
|
|
||||||
|
|
||||||
|
@ModelBase.register("BailingMoeForCausalLM")
|
||||||
|
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
|
||||||
|
index 2b8489c59..1852428b4 100644
|
||||||
|
--- a/gguf-py/gguf/constants.py
|
||||||
|
+++ b/gguf-py/gguf/constants.py
|
||||||
|
@@ -413,6 +413,7 @@ class MODEL_ARCH(IntEnum):
|
||||||
|
JAIS = auto()
|
||||||
|
NEMOTRON = auto()
|
||||||
|
NEMOTRON_H = auto()
|
||||||
|
+ NEMOTRON_H_MOE = auto()
|
||||||
|
EXAONE = auto()
|
||||||
|
EXAONE4 = auto()
|
||||||
|
GRANITE = auto()
|
||||||
|
@@ -786,6 +787,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
|
MODEL_ARCH.JAIS: "jais",
|
||||||
|
MODEL_ARCH.NEMOTRON: "nemotron",
|
||||||
|
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
|
||||||
|
+ MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe",
|
||||||
|
MODEL_ARCH.EXAONE: "exaone",
|
||||||
|
MODEL_ARCH.EXAONE4: "exaone4",
|
||||||
|
MODEL_ARCH.GRANITE: "granite",
|
||||||
|
@@ -2529,6 +2531,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
],
|
||||||
|
+ MODEL_ARCH.NEMOTRON_H_MOE: [
|
||||||
|
+ MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
+ MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
+ MODEL_TENSOR.OUTPUT,
|
||||||
|
+ MODEL_TENSOR.ATTN_NORM,
|
||||||
|
+ MODEL_TENSOR.SSM_IN,
|
||||||
|
+ MODEL_TENSOR.SSM_CONV1D,
|
||||||
|
+ MODEL_TENSOR.SSM_DT,
|
||||||
|
+ MODEL_TENSOR.SSM_A,
|
||||||
|
+ MODEL_TENSOR.SSM_D,
|
||||||
|
+ MODEL_TENSOR.SSM_NORM,
|
||||||
|
+ MODEL_TENSOR.SSM_OUT,
|
||||||
|
+ MODEL_TENSOR.ATTN_Q,
|
||||||
|
+ MODEL_TENSOR.ATTN_K,
|
||||||
|
+ MODEL_TENSOR.ATTN_V,
|
||||||
|
+ MODEL_TENSOR.ATTN_OUT,
|
||||||
|
+ MODEL_TENSOR.FFN_DOWN,
|
||||||
|
+ MODEL_TENSOR.FFN_UP,
|
||||||
|
+ # experts
|
||||||
|
+ MODEL_TENSOR.FFN_GATE_INP,
|
||||||
|
+ MODEL_TENSOR.FFN_UP_EXP,
|
||||||
|
+ MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
|
+ # shared expert
|
||||||
|
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||||
|
+ MODEL_TENSOR.FFN_UP_SHEXP,
|
||||||
|
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||||
|
+ ],
|
||||||
|
MODEL_ARCH.EXAONE: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
|
||||||
|
index d9c87da19..7a3c7c5e0 100644
|
||||||
|
--- a/gguf-py/gguf/tensor_mapping.py
|
||||||
|
+++ b/gguf-py/gguf/tensor_mapping.py
|
||||||
|
@@ -377,6 +377,7 @@ class TensorNameMap:
|
||||||
|
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
||||||
|
"model.layers.{bid}.mlp.router.gate", # afmoe
|
||||||
|
"layers.{bid}.gate", # mistral-large
|
||||||
|
+ "backbone.layers.{bid}.mixer.gate", # nemotron-h-moe
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||||
|
@@ -390,6 +391,7 @@ class TensorNameMap:
|
||||||
|
"model.layers.{bid}.mlp.expert_bias", # afmoe
|
||||||
|
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
|
||||||
|
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
|
||||||
|
+ "backbone.layers.{bid}.mixer.gate.e_score_correction_bias" # nemotron-h-moe
|
||||||
|
),
|
||||||
|
|
||||||
|
# Feed-forward up
|
||||||
|
@@ -438,7 +440,7 @@ class TensorNameMap:
|
||||||
|
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
|
||||||
|
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
|
||||||
|
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||||
|
- "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe
|
||||||
|
+ "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe, nemotron-h-moe (merged)
|
||||||
|
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||||
|
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||||
|
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
|
||||||
|
@@ -452,6 +454,7 @@ class TensorNameMap:
|
||||||
|
"model.layers.{bid}.feed_forward.down_proj",
|
||||||
|
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||||
|
"layers.{bid}.shared_experts.w3", # mistral-large
|
||||||
|
+ "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.FFN_UP_CHEXP: (
|
||||||
|
@@ -546,7 +549,7 @@ class TensorNameMap:
|
||||||
|
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
|
||||||
|
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
|
||||||
|
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||||
|
- "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe
|
||||||
|
+ "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe nemotron-h-moe (merged)
|
||||||
|
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||||
|
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||||
|
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||||
|
@@ -561,6 +564,7 @@ class TensorNameMap:
|
||||||
|
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||||
|
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||||
|
"layers.{bid}.shared_experts.w2", # mistral-large
|
||||||
|
+ "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.FFN_DOWN_CHEXP: (
|
||||||
|
@@ -704,6 +708,7 @@ class TensorNameMap:
|
||||||
|
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
|
||||||
|
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
|
||||||
|
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
|
||||||
|
+ "backbone.layers.{bid}.mixer.dt", # nemotron-h-moe
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.SSM_DT_NORM: (
|
||||||
|
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
||||||
|
index a5fe4f66c..ac8b5e033 100644
|
||||||
|
--- a/src/llama-arch.cpp
|
||||||
|
+++ b/src/llama-arch.cpp
|
||||||
|
@@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
|
{ LLM_ARCH_JAIS, "jais" },
|
||||||
|
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
||||||
|
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
|
||||||
|
+ { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
|
||||||
|
{ LLM_ARCH_EXAONE, "exaone" },
|
||||||
|
{ LLM_ARCH_EXAONE4, "exaone4" },
|
||||||
|
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||||
|
@@ -1765,6 +1766,39 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
+ {
|
||||||
|
+ LLM_ARCH_NEMOTRON_H_MOE,
|
||||||
|
+ {
|
||||||
|
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
+ { LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
+ // mamba(2) ssm layers
|
||||||
|
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||||
|
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||||
|
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||||
|
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||||
|
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||||
|
+ { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||||
|
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||||
|
+ // attention layers
|
||||||
|
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
+ // dense FFN
|
||||||
|
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
+ // MoE FFN (for MoE layers)
|
||||||
|
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||||
|
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||||
|
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||||
|
+ { LLM_TENSOR_FFN_EXP_PROBS_B,"blk.%d.exp_probs_b" },
|
||||||
|
+ // MoE shared expert layer
|
||||||
|
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||||
|
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||||
|
+ },
|
||||||
|
+ },
|
||||||
|
{
|
||||||
|
LLM_ARCH_EXAONE,
|
||||||
|
{
|
||||||
|
@@ -2838,6 +2872,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||||
|
case LLM_ARCH_LFM2:
|
||||||
|
case LLM_ARCH_LFM2MOE:
|
||||||
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
+ case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
|
case LLM_ARCH_QWEN3NEXT:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
||||||
|
index ec9e3a6df..61d73786c 100644
|
||||||
|
--- a/src/llama-arch.h
|
||||||
|
+++ b/src/llama-arch.h
|
||||||
|
@@ -79,6 +79,7 @@ enum llm_arch {
|
||||||
|
LLM_ARCH_JAIS,
|
||||||
|
LLM_ARCH_NEMOTRON,
|
||||||
|
LLM_ARCH_NEMOTRON_H,
|
||||||
|
+ LLM_ARCH_NEMOTRON_H_MOE,
|
||||||
|
LLM_ARCH_EXAONE,
|
||||||
|
LLM_ARCH_EXAONE4,
|
||||||
|
LLM_ARCH_RWKV6,
|
||||||
|
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
|
||||||
|
index 43620df78..763202d87 100644
|
||||||
|
--- a/src/llama-graph.cpp
|
||||||
|
+++ b/src/llama-graph.cpp
|
||||||
|
@@ -1089,6 +1089,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
|
cur = ggml_relu(ctx0, cur);
|
||||||
|
cb(cur, "ffn_moe_relu", il);
|
||||||
|
} break;
|
||||||
|
+ case LLM_FFN_RELU_SQR:
|
||||||
|
+ if (gate_exps) {
|
||||||
|
+ // TODO: add support for gated squared relu
|
||||||
|
+ GGML_ABORT("fatal error: gated squared relu not implemented");
|
||||||
|
+ } else {
|
||||||
|
+ cur = ggml_relu(ctx0, cur);
|
||||||
|
+ cur = ggml_sqr(ctx0, cur);
|
||||||
|
+ cb(cur, "ffn_moe_relu_sqr", il);
|
||||||
|
+ }
|
||||||
|
+ break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||||
|
index 3c503b424..94dee78c3 100644
|
||||||
|
--- a/src/llama-model.cpp
|
||||||
|
+++ b/src/llama-model.cpp
|
||||||
|
@@ -120,6 +120,8 @@ const char * llm_type_name(llm_type type) {
|
||||||
|
case LLM_TYPE_16B_A1B: return "16B.A1B";
|
||||||
|
case LLM_TYPE_21B_A3B: return "21B.A3B";
|
||||||
|
case LLM_TYPE_30B_A3B: return "30B.A3B";
|
||||||
|
+ case LLM_TYPE_31B_A3_5B: return "31B.A3.5B";
|
||||||
|
+ case LLM_TYPE_80B_A3B: return "80B.A3B";
|
||||||
|
case LLM_TYPE_100B_A6B: return "100B.A6B";
|
||||||
|
case LLM_TYPE_106B_A12B: return "106B.A12B";
|
||||||
|
case LLM_TYPE_230B_A10B: return "230B.A10B";
|
||||||
|
@@ -1788,6 +1790,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
+ case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||||
|
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||||
|
@@ -1803,7 +1806,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
|
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||||
|
+ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
||||||
|
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
|
||||||
|
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||||
|
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||||
|
+
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
+ case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B
|
||||||
|
case 56: type = LLM_TYPE_9B; break;
|
||||||
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
|
}
|
||||||
|
@@ -5175,6 +5185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
+ case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
|
{
|
||||||
|
// mamba2 Mixer SSM params
|
||||||
|
// NOTE: int64_t for tensor dimensions
|
||||||
|
@@ -5185,6 +5196,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
|
const int64_t n_group = hparams.ssm_n_group;
|
||||||
|
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
|
||||||
|
|
||||||
|
+ const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||||
|
+ const int64_t n_ff_shexp = hparams.n_ff_shexp;
|
||||||
|
+
|
||||||
|
// embeddings
|
||||||
|
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||||
|
|
||||||
|
@@ -5234,12 +5248,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
|
||||||
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
|
||||||
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||||
|
- } else {
|
||||||
|
- // mlp layers
|
||||||
|
- layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
|
||||||
|
- layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
|
||||||
|
- layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||||
|
- layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
|
||||||
|
+ } else {
|
||||||
|
+ if (n_expert != 0) {
|
||||||
|
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
|
||||||
|
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0);
|
||||||
|
+
|
||||||
|
+ // MoE branch
|
||||||
|
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||||
|
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||||
|
+
|
||||||
|
+ // Shared expert branch
|
||||||
|
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
||||||
|
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||||
|
+
|
||||||
|
+ } else {
|
||||||
|
+ // mlp layers
|
||||||
|
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
|
||||||
|
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
|
||||||
|
+ layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||||
|
+ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
|
||||||
|
+ }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
@@ -6870,7 +6898,8 @@ void llama_model::print_info() const {
|
||||||
|
arch == LLM_ARCH_PLAMO2 ||
|
||||||
|
arch == LLM_ARCH_GRANITE_HYBRID ||
|
||||||
|
arch == LLM_ARCH_QWEN3NEXT ||
|
||||||
|
- arch == LLM_ARCH_NEMOTRON_H) {
|
||||||
|
+ arch == LLM_ARCH_NEMOTRON_H ||
|
||||||
|
+ arch == LLM_ARCH_NEMOTRON_H_MOE) {
|
||||||
|
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
||||||
|
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
||||||
|
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
||||||
|
@@ -6926,7 +6955,8 @@ void llama_model::print_info() const {
|
||||||
|
if (arch == LLM_ARCH_MINICPM ||
|
||||||
|
arch == LLM_ARCH_GRANITE ||
|
||||||
|
arch == LLM_ARCH_GRANITE_MOE ||
|
||||||
|
- arch == LLM_ARCH_GRANITE_HYBRID) {
|
||||||
|
+ arch == LLM_ARCH_GRANITE_HYBRID ||
|
||||||
|
+ arch == LLM_ARCH_NEMOTRON_H_MOE) {
|
||||||
|
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
||||||
|
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
||||||
|
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
||||||
|
@@ -7107,7 +7137,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||||
|
if (arch == LLM_ARCH_FALCON_H1) {
|
||||||
|
filter_attn = [&](int32_t) { return true; };
|
||||||
|
filter_recr = [&](int32_t) { return true; };
|
||||||
|
- } else if (arch == LLM_ARCH_NEMOTRON_H) {
|
||||||
|
+ } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) {
|
||||||
|
filter_attn = [&](int32_t il) {
|
||||||
|
return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
|
||||||
|
};
|
||||||
|
@@ -7478,6 +7508,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||||
|
llm = std::make_unique<llm_build_nemotron>(*this, params);
|
||||||
|
} break;
|
||||||
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
+ case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
|
{
|
||||||
|
llm = std::make_unique<llm_build_nemotron_h>(*this, params);
|
||||||
|
} break;
|
||||||
|
@@ -7765,6 +7796,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||||
|
case LLM_ARCH_ARWKV7:
|
||||||
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||||
|
case LLM_ARCH_NEMOTRON_H:
|
||||||
|
+ case LLM_ARCH_NEMOTRON_H_MOE:
|
||||||
|
return LLAMA_ROPE_TYPE_NONE;
|
||||||
|
|
||||||
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||||
|
diff --git a/src/llama-model.h b/src/llama-model.h
|
||||||
|
index cbf4e1bfa..b378b23ec 100644
|
||||||
|
--- a/src/llama-model.h
|
||||||
|
+++ b/src/llama-model.h
|
||||||
|
@@ -114,6 +114,7 @@ enum llm_type {
|
||||||
|
LLM_TYPE_16B_A1B,
|
||||||
|
LLM_TYPE_21B_A3B, // Ernie MoE small
|
||||||
|
LLM_TYPE_30B_A3B,
|
||||||
|
+ LLM_TYPE_31B_A3_5B,
|
||||||
|
LLM_TYPE_80B_A3B, // Qwen3 Next
|
||||||
|
LLM_TYPE_100B_A6B,
|
||||||
|
LLM_TYPE_106B_A12B, // GLM-4.5-Air
|
||||||
|
diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp
|
||||||
|
index 541434888..eb135e63f 100644
|
||||||
|
--- a/src/models/nemotron-h.cpp
|
||||||
|
+++ b/src/models/nemotron-h.cpp
|
||||||
|
@@ -107,12 +107,41 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
|
||||||
|
- cur = build_ffn(cur,
|
||||||
|
- model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||||
|
- NULL, NULL, NULL,
|
||||||
|
- model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||||
|
- NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
|
||||||
|
- cb(cur, "ffn_out", il);
|
||||||
|
+ if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||||
|
+ cur = build_ffn(cur,
|
||||||
|
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||||
|
+ NULL, NULL, NULL,
|
||||||
|
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||||
|
+ NULL,
|
||||||
|
+ LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
|
||||||
|
+ cb(cur, "ffn_out", il);
|
||||||
|
+ } else {
|
||||||
|
+ ggml_tensor * ffn_inp = cur;
|
||||||
|
+ ggml_tensor * moe_out =
|
||||||
|
+ build_moe_ffn(ffn_inp,
|
||||||
|
+ model.layers[il].ffn_gate_inp,
|
||||||
|
+ model.layers[il].ffn_up_exps,
|
||||||
|
+ nullptr, // no gate
|
||||||
|
+ model.layers[il].ffn_down_exps,
|
||||||
|
+ model.layers[il].ffn_exp_probs_b,
|
||||||
|
+ n_expert, n_expert_used,
|
||||||
|
+ LLM_FFN_RELU_SQR, hparams.expert_weights_norm,
|
||||||
|
+ true, hparams.expert_weights_scale,
|
||||||
|
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
|
||||||
|
+ il);
|
||||||
|
+ cb(moe_out, "ffn_moe_out", il);
|
||||||
|
+
|
||||||
|
+ ggml_tensor * ffn_shexp = build_ffn(ffn_inp,
|
||||||
|
+ model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||||
|
+ NULL /* no gate */ , NULL, NULL,
|
||||||
|
+ model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||||
|
+ NULL,
|
||||||
|
+ LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
|
||||||
|
+ cb(ffn_shexp, "ffn_shexp", il);
|
||||||
|
+
|
||||||
|
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||||
|
+ cb(cur, "ffn_out", il);
|
||||||
|
+ }
|
||||||
|
|
||||||
|
cur = build_cvec(cur, il);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
@ -143,7 +143,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||||
var llamaModel *llama.Model
|
var llamaModel *llama.Model
|
||||||
var textProcessor model.TextProcessor
|
var textProcessor model.TextProcessor
|
||||||
var err error
|
var err error
|
||||||
if envconfig.NewEngine(true) || f.KV().OllamaEngineRequired() {
|
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||||
if len(projectors) == 0 {
|
if len(projectors) == 0 {
|
||||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -1534,7 +1534,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
|
||||||
unsafe.SliceData(mropeSections),
|
unsafe.SliceData(mropeSections),
|
||||||
C.int(opts.Type),
|
C.int(opts.Type),
|
||||||
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
||||||
C.float(ropeBase), C.float(ropeScale),
|
C.float(ropeBase),
|
||||||
|
C.float(ropeScale),
|
||||||
C.float(opts.YaRN.ExtrapolationFactor),
|
C.float(opts.YaRN.ExtrapolationFactor),
|
||||||
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
||||||
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
||||||
|
|
@ -1546,9 +1547,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
|
||||||
dequant,
|
dequant,
|
||||||
positions.(*Tensor).t,
|
positions.(*Tensor).t,
|
||||||
opts.Factors.(*Tensor).t,
|
opts.Factors.(*Tensor).t,
|
||||||
C.int(ropeDim), C.int(opts.Type),
|
C.int(ropeDim),
|
||||||
|
C.int(opts.Type),
|
||||||
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
||||||
C.float(ropeBase), C.float(ropeScale),
|
C.float(ropeBase),
|
||||||
|
C.float(ropeScale),
|
||||||
C.float(opts.YaRN.ExtrapolationFactor),
|
C.float(opts.YaRN.ExtrapolationFactor),
|
||||||
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
||||||
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,13 @@ func WithMRoPE(sections []int) func(*Options) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithVision(sections []int) func(*Options) {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.Type |= 1<<3 | 1<<4
|
||||||
|
opts.MRoPE.Sections = sections
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithInterleaveMRoPE(sections []int) func(*Options) {
|
func WithInterleaveMRoPE(sections []int) func(*Options) {
|
||||||
return func(opts *Options) {
|
return func(opts *Options) {
|
||||||
opts.Type |= 1<<3 | 1<<5
|
opts.Type |= 1<<3 | 1<<5
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,7 @@ package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
"fmt"
|
|
||||||
"iter"
|
"iter"
|
||||||
"log/slog"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
|
@ -245,14 +243,6 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type lazyIdsString struct {
|
|
||||||
ids []int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l lazyIdsString) LogValue() slog.Value {
|
|
||||||
return slog.AnyValue(fmt.Sprint(l.ids))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
|
|
@ -277,6 +267,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -129,35 +129,34 @@ func (o Options) headDim() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
vocab := &model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{
|
||||||
|
int32(cmp.Or(
|
||||||
|
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||||
|
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||||
|
EOS: []int32{
|
||||||
|
int32(cmp.Or(
|
||||||
|
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||||
|
//nolint:misspell
|
||||||
|
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
|
||||||
|
// support it for compatibility.
|
||||||
|
c.Uint("tokenizer.ggml.seperator_token_id"),
|
||||||
|
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
var processor model.TextProcessor
|
var processor model.TextProcessor
|
||||||
switch c.String("tokenizer.ggml.model", "bert") {
|
switch c.String("tokenizer.ggml.model", "bert") {
|
||||||
case "bert":
|
case "bert":
|
||||||
processor = model.NewWordPiece(
|
processor = model.NewWordPiece(vocab, true)
|
||||||
&model.Vocabulary{
|
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
|
||||||
BOS: []int32{
|
|
||||||
int32(cmp.Or(
|
|
||||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
|
||||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
|
||||||
)),
|
|
||||||
},
|
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
|
||||||
EOS: []int32{
|
|
||||||
int32(cmp.Or(
|
|
||||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
|
||||||
//nolint:misspell
|
|
||||||
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
|
|
||||||
// support it for compatibility.
|
|
||||||
c.Uint("tokenizer.ggml.seperator_token_id"),
|
|
||||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
|
||||||
)),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
true,
|
|
||||||
)
|
|
||||||
default:
|
default:
|
||||||
return nil, model.ErrUnsupportedTokenizer
|
return nil, model.ErrUnsupportedTokenizer
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ import (
|
||||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||||
_ "github.com/ollama/ollama/model/models/mllama"
|
_ "github.com/ollama/ollama/model/models/mllama"
|
||||||
_ "github.com/ollama/ollama/model/models/nomicbert"
|
_ "github.com/ollama/ollama/model/models/nomicbert"
|
||||||
|
_ "github.com/ollama/ollama/model/models/olmo3"
|
||||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,223 @@
|
||||||
|
package olmo3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cacheTypeSWA = 0
|
||||||
|
cacheTypeCausal = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
hiddenSize, numHeads, numKVHeads int
|
||||||
|
eps, ropeBase, ropeScale float32
|
||||||
|
|
||||||
|
originalContextLength int
|
||||||
|
attnFactor float32
|
||||||
|
|
||||||
|
ropeType string
|
||||||
|
ropeExtrapolation float32
|
||||||
|
|
||||||
|
slidingWindowPattern []bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
model.TextProcessor
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
Layers []Layer `gguf:"blk"`
|
||||||
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
vocabulary := model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||||
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
EOS: append(
|
||||||
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := model.NewBytePairEncoding(
|
||||||
|
&vocabulary,
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
)
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
TextProcessor: processor,
|
||||||
|
Layers: make([]Layer, c.Uint("block_count")),
|
||||||
|
Options: Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
|
ropeBase: c.Float("rope.freq_base", 1e4),
|
||||||
|
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||||
|
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||||
|
attnFactor: c.Float("rope.scaling.attn_factor", 1),
|
||||||
|
ropeType: c.String("rope.scaling.type"),
|
||||||
|
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
|
||||||
|
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewWrapperCache(
|
||||||
|
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||||
|
kvcache.NewCausalCache(m.Shift),
|
||||||
|
)
|
||||||
|
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type SelfAttention struct {
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||||
|
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, isSWA bool) ml.Tensor {
|
||||||
|
freqScale := float32(1.0)
|
||||||
|
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||||
|
|
||||||
|
if !isSWA {
|
||||||
|
freqScale = 1. / o.ropeScale
|
||||||
|
if o.originalContextLength > 0 {
|
||||||
|
ropeOpts = append(ropeOpts,
|
||||||
|
rope.WithOriginalContextLength(o.originalContextLength),
|
||||||
|
rope.WithExtrapolationFactor(o.ropeExtrapolation),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nn.RoPE(ctx, states, positions, o.hiddenSize/o.numHeads, o.ropeBase, freqScale, ropeOpts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
|
||||||
|
batchSize := hiddenState.Dim(1)
|
||||||
|
headDim := m.hiddenSize / m.numHeads
|
||||||
|
|
||||||
|
query := sa.Query.Forward(ctx, hiddenState)
|
||||||
|
query = sa.QNorm.Forward(ctx, query, m.eps)
|
||||||
|
query = query.Reshape(ctx, headDim, m.numHeads, batchSize)
|
||||||
|
query = m.Options.applyRotaryPositionEmbeddings(ctx, query, positions, isSWA)
|
||||||
|
|
||||||
|
key := sa.Key.Forward(ctx, hiddenState)
|
||||||
|
key = sa.KNorm.Forward(ctx, key, m.eps)
|
||||||
|
key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize)
|
||||||
|
key = m.Options.applyRotaryPositionEmbeddings(ctx, key, positions, isSWA)
|
||||||
|
|
||||||
|
value := sa.Value.Forward(ctx, hiddenState)
|
||||||
|
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||||
|
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
|
||||||
|
|
||||||
|
return sa.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
isSWA := m.isSWALayer(layer)
|
||||||
|
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift, isSWA), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP struct {
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor {
|
||||||
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||||
|
return mlp.Down.Forward(ctx, hiddenState)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Layer struct {
|
||||||
|
SelfAttention *SelfAttention
|
||||||
|
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
||||||
|
MLP *MLP
|
||||||
|
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
|
||||||
|
residual := hiddenState
|
||||||
|
|
||||||
|
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
|
||||||
|
|
||||||
|
if outputs != nil {
|
||||||
|
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||||
|
residual = residual.Rows(ctx, outputs)
|
||||||
|
}
|
||||||
|
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
|
|
||||||
|
hiddenState = hiddenState.Add(ctx, residual)
|
||||||
|
residual = hiddenState
|
||||||
|
|
||||||
|
hiddenState = l.MLP.Forward(ctx, hiddenState, m)
|
||||||
|
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
|
|
||||||
|
return hiddenState.Add(ctx, residual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OLMo3 has Sliding Window Attention (SWA) for 3 out of every 4 layers.
|
||||||
|
func (m *Model) isSWALayer(layerIdx int) bool {
|
||||||
|
return m.Options.slidingWindowPattern[layerIdx]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||||
|
|
||||||
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
m.Cache.SetLayer(i)
|
||||||
|
cacheType := cacheTypeSWA
|
||||||
|
|
||||||
|
isSWA := m.isSWALayer(i)
|
||||||
|
if !isSWA {
|
||||||
|
cacheType = cacheTypeCausal
|
||||||
|
}
|
||||||
|
|
||||||
|
wc, ok := m.Cache.(*kvcache.WrapperCache)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("expected *kvcache.WrapperCache, got %T", m.Cache)
|
||||||
|
}
|
||||||
|
wc.SetLayerType(cacheType)
|
||||||
|
|
||||||
|
var outputs ml.Tensor
|
||||||
|
if i == len(m.Layers)-1 {
|
||||||
|
outputs = batch.Outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenState), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("olmo3", New)
|
||||||
|
}
|
||||||
|
|
@ -2,7 +2,6 @@ package qwen25vl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"image"
|
"image"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
|
|
@ -33,7 +32,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
EOS: append(
|
EOS: append(
|
||||||
|
|
@ -54,19 +53,18 @@ func New(c fs.Config) (model.Model, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
|
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
|
||||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate tensor dimensions
|
// Calculate tensor dimensions
|
||||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
|
||||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
|
||||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||||
|
|
||||||
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
||||||
|
|
@ -85,11 +83,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||||
}
|
}
|
||||||
|
|
||||||
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
|
// Reset position cache
|
||||||
|
m.positionCache = m.positionCache[:0]
|
||||||
var result []*input.Input
|
var result []*input.Input
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -98,40 +98,37 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
visionEndToken int32 = 151653
|
visionEndToken int32 = 151653
|
||||||
)
|
)
|
||||||
|
|
||||||
nImg := 0
|
appendInput := func(i *input.Input, p int) int {
|
||||||
|
result = append(result, i)
|
||||||
|
m.positionCache = append(m.positionCache, int32(p))
|
||||||
|
return p + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var p int
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if inp.Multimodal == nil {
|
if inp.Multimodal == nil {
|
||||||
// If not a multimodal input, add it to the result unchanged
|
// If not a multimodal input, add it to the result unchanged
|
||||||
result = append(result, inp)
|
p = appendInput(inp, p)
|
||||||
} else {
|
} else {
|
||||||
// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
|
|
||||||
// the image tokens with a prompt, so we add a prefix here
|
|
||||||
nImg++
|
|
||||||
pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
|
||||||
}
|
|
||||||
for i := range pre {
|
|
||||||
result = append(result, &input.Input{Token: pre[i]})
|
|
||||||
}
|
|
||||||
|
|
||||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
|
||||||
|
|
||||||
// First add the vision start token
|
// First add the vision start token
|
||||||
result = append(result, &input.Input{Token: visionStartToken})
|
p = appendInput(&input.Input{Token: visionStartToken}, p)
|
||||||
|
|
||||||
// Add the image token with the multimodal tensor data at the first position
|
// Add the image token with the multimodal tensor data at the first position
|
||||||
result = append(result, &input.Input{
|
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||||
|
appendInput(&input.Input{
|
||||||
Token: imageToken,
|
Token: imageToken,
|
||||||
Multimodal: inp.Multimodal,
|
Multimodal: inp.Multimodal,
|
||||||
MultimodalHash: inp.MultimodalHash,
|
MultimodalHash: inp.MultimodalHash,
|
||||||
SameBatch: patchesPerChunk,
|
SameBatch: tokensPerGrid,
|
||||||
})
|
}, p)
|
||||||
|
|
||||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||||
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
for range tokensPerGrid - 1 {
|
||||||
|
appendInput(&input.Input{Token: imageToken}, p)
|
||||||
|
}
|
||||||
|
|
||||||
result = append(result, &input.Input{Token: visionEndToken})
|
grid := inp.Multimodal[0].Data.(*Grid)
|
||||||
|
p = appendInput(&input.Input{Token: visionEndToken}, p+max(grid.Width/m.spatialMergeSize, grid.Height/m.spatialMergeSize))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -139,9 +136,58 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
// Initial token embedding
|
||||||
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
|
positionSlice := func() [][]int32 {
|
||||||
|
s := [][]int32{
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
}
|
||||||
|
for i, position := range batch.Positions {
|
||||||
|
if position < int32(len(m.positionCache)) {
|
||||||
|
position = m.positionCache[position]
|
||||||
|
} else if len(m.positionCache) > 0 {
|
||||||
|
position = position - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
s[0][i] = position
|
||||||
|
s[1][i] = position
|
||||||
|
s[2][i] = position
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, mi := range batch.Multimodal {
|
||||||
|
img := mi.Multimodal[0].Tensor
|
||||||
|
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
||||||
|
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
|
||||||
|
for i := range img.Dim(1) {
|
||||||
|
w := grid.Width / m.spatialMergeSize
|
||||||
|
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||||
|
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||||
|
|
||||||
|
// Process through transformer layers
|
||||||
|
for i, layer := range m.TextModel.Layers {
|
||||||
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
|
var lastLayerOutputs ml.Tensor
|
||||||
|
if i == len(m.TextModel.Layers)-1 {
|
||||||
|
lastLayerOutputs = batch.Outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
||||||
|
|
@ -8,20 +8,17 @@ import (
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/ml/nn/rope"
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TextOptions struct {
|
type TextOptions struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
ropeDim, originalContextLength int
|
ropeDim, originalContextLength int
|
||||||
eps, ropeBase, ropeScale float32
|
eps, ropeBase, ropeScale float32
|
||||||
|
mropeSections []int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale,
|
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithMRoPE(o.mropeSections))
|
||||||
rope.WithOriginalContextLength(o.originalContextLength),
|
|
||||||
rope.WithTypeNeoX(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextModel struct {
|
type TextModel struct {
|
||||||
|
|
@ -31,6 +28,7 @@ type TextModel struct {
|
||||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
*TextOptions
|
*TextOptions
|
||||||
|
positionCache []int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTextModel(c fs.Config) *TextModel {
|
func NewTextModel(c fs.Config) *TextModel {
|
||||||
|
|
@ -45,6 +43,14 @@ func NewTextModel(c fs.Config) *TextModel {
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||||
|
mropeSections: func() []int {
|
||||||
|
sections := c.Ints("rope.mrope_section")
|
||||||
|
s := make([]int, len(sections))
|
||||||
|
for i, section := range sections {
|
||||||
|
s[i] = int(section)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -84,6 +90,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||||
|
|
||||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
m.positionCache = nil
|
||||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -130,28 +137,3 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
|
||||||
// Initial token embedding
|
|
||||||
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
|
||||||
|
|
||||||
for _, mi := range batch.Multimodal {
|
|
||||||
img := mi.Multimodal[0].Tensor
|
|
||||||
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process through transformer layers
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
cache.SetLayer(i)
|
|
||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
|
||||||
if i == len(m.Layers)-1 {
|
|
||||||
lastLayerOutputs = outputs
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
|
||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -7,48 +7,28 @@ import (
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
)
|
)
|
||||||
|
|
||||||
// We only support batch size of 1
|
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
|
||||||
var batchSize int = 1
|
// Initialize a 2D mask with -Inf
|
||||||
|
s := make([][]float32, seqLength)
|
||||||
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
for i := range s {
|
||||||
x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
|
s[i] = slices.Repeat([]float32{float32(math.Inf(-1))}, seqLength)
|
||||||
x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
|
|
||||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
|
|
||||||
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
|
|
||||||
}
|
|
||||||
|
|
||||||
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
|
|
||||||
// Create a flat slice for the mask (all -inf initially to block all attention)
|
|
||||||
flat := make([]float32, seqLength*seqLength)
|
|
||||||
for i := range flat {
|
|
||||||
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill in the mask with zeros for tokens that CAN attend to each other
|
// Fill in the mask with zeros for tokens that CAN attend to each other
|
||||||
for i := 1; i < len(bounds); i++ {
|
for i := 1; i < len(bounds); i++ {
|
||||||
start := bounds[i-1]
|
start, end := bounds[i-1], bounds[i]
|
||||||
end := bounds[i]
|
// Enable attention within this sequence block
|
||||||
|
|
||||||
// Enable attention within this sequence block by setting values to 0
|
|
||||||
for row := start; row < end; row++ {
|
for row := start; row < end; row++ {
|
||||||
for col := start; col < end; col++ {
|
for col := start; col < end; col++ {
|
||||||
idx := row*seqLength + col
|
s[row][col] = 0.0
|
||||||
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mask := ctx.Input().FromFloats(flat, seqLength, seqLength)
|
return ctx.Input().FromFloats(slices.Concat(s...), seqLength, seqLength)
|
||||||
|
|
||||||
// Reshape to match [seqLength, seqLength, 1] for broadcasting
|
|
||||||
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
|
|
||||||
|
|
||||||
return mask
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type VisionSelfAttention struct {
|
type VisionSelfAttention struct {
|
||||||
|
|
@ -58,17 +38,17 @@ type VisionSelfAttention struct {
|
||||||
Output *nn.Linear `gguf:"attn_out"`
|
Output *nn.Linear `gguf:"attn_out"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
query := sa.Query.Forward(ctx, hiddenStates)
|
query := sa.Query.Forward(ctx, hiddenStates)
|
||||||
key := sa.Key.Forward(ctx, hiddenStates)
|
key := sa.Key.Forward(ctx, hiddenStates)
|
||||||
value := sa.Value.Forward(ctx, hiddenStates)
|
value := sa.Value.Forward(ctx, hiddenStates)
|
||||||
|
|
||||||
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
|
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1))
|
||||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1))
|
||||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1))
|
||||||
|
|
||||||
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||||
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||||
|
|
||||||
// Scale factor for scaled dot-product attention
|
// Scale factor for scaled dot-product attention
|
||||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||||
|
|
@ -77,6 +57,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
|
||||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
|
||||||
kq := key.MulmatFullPrec(ctx, query)
|
kq := key.MulmatFullPrec(ctx, query)
|
||||||
kq = kq.Scale(ctx, scale)
|
kq = kq.Scale(ctx, scale)
|
||||||
if mask != nil {
|
if mask != nil {
|
||||||
|
|
@ -85,7 +66,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
|
||||||
kq = kq.Softmax(ctx)
|
kq = kq.Softmax(ctx)
|
||||||
kqv := value.Mulmat(ctx, kq)
|
kqv := value.Mulmat(ctx, kq)
|
||||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, attention)
|
return sa.Output.Forward(ctx, attention)
|
||||||
}
|
}
|
||||||
|
|
@ -98,10 +79,7 @@ type VisionMLP struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
// Using activation as specified in config (likely GELU or SiLU/Swish)
|
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||||
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
|
|
||||||
hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
|
||||||
|
|
||||||
return mlp.Down.Forward(ctx, hiddenStates)
|
return mlp.Down.Forward(ctx, hiddenStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -112,10 +90,10 @@ type VisionEncoderLayer struct {
|
||||||
MLP *VisionMLP
|
MLP *VisionMLP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
residual := hiddenStates
|
residual := hiddenStates
|
||||||
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||||
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
|
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positions, mask, opts)
|
||||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
|
||||||
residual = hiddenStates
|
residual = hiddenStates
|
||||||
|
|
@ -139,6 +117,17 @@ type VisionModelOptions struct {
|
||||||
temporalPatchSize int
|
temporalPatchSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o VisionModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||||
|
return nn.RoPE(ctx, states, positions, o.headDim/2, o.ropeTheta, 1,
|
||||||
|
rope.WithVision([]int{
|
||||||
|
o.headDim / 4,
|
||||||
|
o.headDim / 4,
|
||||||
|
o.headDim / 4,
|
||||||
|
o.headDim / 4,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
type PatchEmbedding struct {
|
type PatchEmbedding struct {
|
||||||
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
|
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
|
||||||
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
|
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||||
|
|
@ -186,7 +175,7 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, op
|
||||||
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
|
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
|
||||||
|
|
||||||
// Reshape the normalized output to view the hidden size dimension
|
// Reshape the normalized output to view the hidden size dimension
|
||||||
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
|
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize))
|
||||||
hidden := pm.MLP0.Forward(ctx, reshaped)
|
hidden := pm.MLP0.Forward(ctx, reshaped)
|
||||||
activated := hidden.GELU(ctx)
|
activated := hidden.GELU(ctx)
|
||||||
|
|
||||||
|
|
@ -209,36 +198,53 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||||
// Extract patch embeddings
|
// Extract patch embeddings
|
||||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
||||||
|
|
||||||
positionEmbedding := m.PositionalEmbedding(ctx, grid)
|
index, bounds := m.windowIndex(grid)
|
||||||
|
|
||||||
windowIndex, bounds := m.WindowIndex(ctx, grid)
|
|
||||||
|
|
||||||
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||||
|
|
||||||
|
windowIndex := ctx.Input().FromInts(index, len(index))
|
||||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
|
||||||
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
|
hiddenStates = hiddenStates.Rows(ctx, windowIndex.Argsort(ctx))
|
||||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
|
||||||
|
|
||||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
|
positions := ctx.Input().FromInts(func() []int32 {
|
||||||
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
|
s := [][]int32{
|
||||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
|
make([]int32, grid.Height*grid.Width),
|
||||||
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
|
make([]int32, grid.Height*grid.Width),
|
||||||
|
make([]int32, grid.Height*grid.Width),
|
||||||
|
make([]int32, grid.Height*grid.Width),
|
||||||
|
}
|
||||||
|
|
||||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
var cur int
|
||||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
for y := 0; y < grid.Height; y += m.spatialMergeSize {
|
||||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
for x := 0; x < grid.Width; x += m.spatialMergeSize {
|
||||||
|
for dy := range 2 {
|
||||||
|
for dx := range 2 {
|
||||||
|
i := int(index[cur/spatialMergeUnit]) * spatialMergeUnit
|
||||||
|
i += cur % spatialMergeUnit
|
||||||
|
s[0][i] = int32(y + dy)
|
||||||
|
s[1][i] = int32(x + dx)
|
||||||
|
s[2][i] = int32(y + dy)
|
||||||
|
s[3][i] = int32(x + dx)
|
||||||
|
cur++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.Concat(s...)
|
||||||
|
}(), grid.Height*grid.Width*4)
|
||||||
|
|
||||||
|
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds)
|
||||||
|
|
||||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
|
|
||||||
// Apply encoder layers
|
// Apply encoder layers
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
if slices.Contains(m.fullAttnBlocks, int32(i)) {
|
if slices.Contains(m.fullAttnBlocks, int32(i)) {
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, nil, m.VisionModelOptions)
|
||||||
} else {
|
} else {
|
||||||
hiddenStates = layer.Forward(
|
hiddenStates = layer.Forward(
|
||||||
ctx,
|
ctx,
|
||||||
hiddenStates,
|
hiddenStates,
|
||||||
cos,
|
positions,
|
||||||
sin,
|
|
||||||
mask,
|
mask,
|
||||||
m.VisionModelOptions,
|
m.VisionModelOptions,
|
||||||
)
|
)
|
||||||
|
|
@ -246,102 +252,43 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
|
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
|
||||||
reverseWindowIndex := windowIndex.Argsort(ctx)
|
return hiddenStates.Rows(ctx, windowIndex)
|
||||||
return hiddenStates.Rows(ctx, reverseWindowIndex)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WindowIndex divides the grid into windows and returns:
|
// windowIndex divides the grid into windows and returns:
|
||||||
// 1. A tensor containing flattened indices of all grid points organized by windows
|
// 1. A slice of grid point indices organized by windows
|
||||||
// 2. A slice of boundaries that mark where each window's data begins and ends
|
// 2. A slice of boundaries that mark where each window's data begins and ends
|
||||||
// in the flattened representation, scaled by spatialMergeSize squared
|
// in the flattened representation, scaled by spatialMergeSize squared
|
||||||
//
|
//
|
||||||
// The boundaries slice always starts with 0 and contains cumulative ending
|
// The boundaries slice always starts with 0 and contains cumulative ending
|
||||||
// positions for each window, allowing downstream processing to identify
|
// positions for each window, allowing downstream processing to identify
|
||||||
// window boundaries in the tensor data.
|
// window boundaries in the tensor data.
|
||||||
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
func (m *VisionModel) windowIndex(grid *Grid) (index []int32, bounds []int) {
|
||||||
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
height := grid.Height / m.spatialMergeSize
|
||||||
|
width := grid.Width / m.spatialMergeSize
|
||||||
|
window := m.windowSize / m.patchSize / m.spatialMergeSize
|
||||||
|
|
||||||
llmGridH := grid.Height / m.spatialMergeSize
|
index = make([]int32, height*width)
|
||||||
llmGridW := grid.Width / m.spatialMergeSize
|
|
||||||
|
|
||||||
// Calculate window parameters
|
bounds = make([]int, 0, ((height+window-1)/window)*((width+window-1)/window)+1)
|
||||||
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
|
bounds = append(bounds, 0)
|
||||||
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
|
|
||||||
|
|
||||||
// Initialize index_new slice
|
var cur int32
|
||||||
var index []int32
|
for y := 0; y < height; y += window {
|
||||||
|
for x := 0; x < width; x += window {
|
||||||
// Initialize bounds with the first element as 0
|
h1 := min(window, height-y)
|
||||||
bounds := []int{0}
|
w1 := min(window, width-x)
|
||||||
totalSeqLen := 0
|
for dy := range h1 {
|
||||||
|
for dx := range w1 {
|
||||||
// Process each window without padding
|
win := (y+dy)*width + (x + dx)
|
||||||
for wh := range numWindowsH {
|
index[win] = cur
|
||||||
for ww := range numWindowsW {
|
cur++
|
||||||
// Calculate window boundaries
|
|
||||||
hStart := wh * vitMergerWindowSize
|
|
||||||
wStart := ww * vitMergerWindowSize
|
|
||||||
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
|
|
||||||
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
|
|
||||||
|
|
||||||
// Calculate sequence length for this window
|
|
||||||
seqLen := (hEnd - hStart) * (wEnd - wStart)
|
|
||||||
|
|
||||||
// Collect indices for this window
|
|
||||||
for h := hStart; h < hEnd; h++ {
|
|
||||||
for w := wStart; w < wEnd; w++ {
|
|
||||||
index = append(index, int32(h*llmGridW+w))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
bounds = append(bounds, int(cur)*window)
|
||||||
totalSeqLen += seqLen
|
|
||||||
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return index, bounds
|
||||||
t := ctx.Input().FromInts(index, len(index))
|
|
||||||
|
|
||||||
return t, bounds
|
|
||||||
}
|
|
||||||
|
|
||||||
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
|
|
||||||
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
|
|
||||||
dim := m.headDim / 2
|
|
||||||
freq := dim / 2
|
|
||||||
theta := float64(m.ropeTheta)
|
|
||||||
merge := m.spatialMergeSize
|
|
||||||
|
|
||||||
// Create frequency patterns for position encoding
|
|
||||||
maxGridSize := max(grid.Height, grid.Width)
|
|
||||||
freqVals := make([]float32, freq*maxGridSize)
|
|
||||||
for i := range maxGridSize {
|
|
||||||
for j := range freq {
|
|
||||||
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize)
|
|
||||||
|
|
||||||
// Create position coordinates (y,x pairs) for the grid
|
|
||||||
// In PyTorch: Equivalent to generating position ids with torch.arange()
|
|
||||||
coords := make([]int32, 0, grid.Height*grid.Width*2)
|
|
||||||
for y := range grid.Height {
|
|
||||||
for x := range grid.Width {
|
|
||||||
coords = append(coords, int32(y), int32(x))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height)
|
|
||||||
|
|
||||||
// Reshape and permute positions to match spatial merging pattern
|
|
||||||
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
|
|
||||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
|
|
||||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
|
|
||||||
|
|
||||||
// Use position indices to look up corresponding frequency values
|
|
||||||
positionalEmbedding := freqs.Rows(ctx, pos)
|
|
||||||
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
|
|
||||||
return positionalEmbedding
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newVisionModel creates a new instance of the Qwen vision model
|
// newVisionModel creates a new instance of the Qwen vision model
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,8 @@ type ImageProcessor struct {
|
||||||
maxPixels int
|
maxPixels int
|
||||||
factor int
|
factor int
|
||||||
rescaleFactor float32
|
rescaleFactor float32
|
||||||
imageMean []float32
|
imageMean [3]float32
|
||||||
imageStd []float32
|
imageStd [3]float32
|
||||||
}
|
}
|
||||||
|
|
||||||
// newImageProcessor creates a new image processor with default values
|
// newImageProcessor creates a new image processor with default values
|
||||||
|
|
@ -34,11 +34,11 @@ func newImageProcessor(c fs.Config) ImageProcessor {
|
||||||
temporalPatchSize: 2,
|
temporalPatchSize: 2,
|
||||||
mergeSize: mergeSize,
|
mergeSize: mergeSize,
|
||||||
minPixels: 56 * 56,
|
minPixels: 56 * 56,
|
||||||
maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit
|
maxPixels: int(c.Uint("vision.max_pixels", 2<<20)), // 2M limit
|
||||||
factor: patchSize * mergeSize,
|
factor: patchSize * mergeSize,
|
||||||
rescaleFactor: 1.0 / 255.0,
|
rescaleFactor: 1.0 / 255.0,
|
||||||
imageMean: imageproc.ClipDefaultMean[:],
|
imageMean: imageproc.ClipDefaultMean,
|
||||||
imageStd: imageproc.ClipDefaultSTD[:],
|
imageStd: imageproc.ClipDefaultSTD,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -90,13 +90,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
|
||||||
// Resize image using existing functions
|
// Resize image using existing functions
|
||||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
||||||
|
|
||||||
normalizedPixels := imageproc.Normalize(
|
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
|
||||||
resizedImg,
|
|
||||||
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
|
|
||||||
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
|
|
||||||
true, // rescale
|
|
||||||
true, // channelFirst
|
|
||||||
)
|
|
||||||
|
|
||||||
// Calculate grid dimensions
|
// Calculate grid dimensions
|
||||||
grid := &Grid{
|
grid := &Grid{
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,292 @@
|
||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeepSeek3ParserState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
DeepSeekCollectingThinking DeepSeek3ParserState = iota
|
||||||
|
DeepSeekCollectingContent
|
||||||
|
DeepSeekCollectingToolCalls
|
||||||
|
DeepSeekCollectingToolOutput
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
deepseekThinkingCloseTag = "</think>"
|
||||||
|
deepseekToolCallsBeginTag = "<|tool▁calls▁begin|>"
|
||||||
|
deepseekToolCallsEndTag = "<|tool▁calls▁end|>"
|
||||||
|
deepseekToolCallBeginTag = "<|tool▁call▁begin|>"
|
||||||
|
deepseekToolCallEndTag = "<|tool▁call▁end|>"
|
||||||
|
deepseekToolSepTag = "<|tool▁sep|>"
|
||||||
|
deepseekToolOutputBeginTag = "<|tool▁output▁begin|>"
|
||||||
|
deepseekToolOutputEndTag = "<|tool▁output▁end|>"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeepSeek3Parser struct {
|
||||||
|
state DeepSeek3ParserState
|
||||||
|
buffer strings.Builder
|
||||||
|
hasThinkingSupport bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) HasToolSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) HasThinkingSupport() bool {
|
||||||
|
return p.hasThinkingSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
|
||||||
|
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||||
|
|
||||||
|
// Check both model capability AND request preference
|
||||||
|
thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
|
||||||
|
|
||||||
|
if !thinkingEnabled {
|
||||||
|
p.state = DeepSeekCollectingContent
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefill && lastMessage.Content != "" {
|
||||||
|
p.state = DeepSeekCollectingContent
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.state = DeepSeekCollectingThinking
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
|
p.setInitialState(lastMessage, tools, thinkValue)
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
type deepseekEvent interface {
|
||||||
|
isDeepSeekEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
type deepseekEventThinkingContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
type deepseekEventContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
type deepseekEventToolCall struct {
|
||||||
|
toolCall api.ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (deepseekEventThinkingContent) isDeepSeekEvent() {}
|
||||||
|
func (deepseekEventContent) isDeepSeekEvent() {}
|
||||||
|
func (deepseekEventToolCall) isDeepSeekEvent() {}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.buffer.WriteString(s)
|
||||||
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
var contentSb strings.Builder
|
||||||
|
var thinkingSb strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
switch event := event.(type) {
|
||||||
|
case deepseekEventToolCall:
|
||||||
|
toolCalls = append(toolCalls, event.toolCall)
|
||||||
|
case deepseekEventThinkingContent:
|
||||||
|
thinkingSb.WriteString(event.content)
|
||||||
|
case deepseekEventContent:
|
||||||
|
contentSb.WriteString(event.content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) parseEvents() []deepseekEvent {
|
||||||
|
var all []deepseekEvent
|
||||||
|
|
||||||
|
keepLooping := true
|
||||||
|
for keepLooping {
|
||||||
|
var events []deepseekEvent
|
||||||
|
events, keepLooping = p.eat()
|
||||||
|
if len(events) > 0 {
|
||||||
|
all = append(all, events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) eat() ([]deepseekEvent, bool) {
|
||||||
|
var events []deepseekEvent
|
||||||
|
bufStr := p.buffer.String()
|
||||||
|
if bufStr == "" {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p.state {
|
||||||
|
case DeepSeekCollectingThinking:
|
||||||
|
if strings.Contains(bufStr, deepseekThinkingCloseTag) { // thinking[</think>] -> content
|
||||||
|
split := strings.SplitN(bufStr, deepseekThinkingCloseTag, 2)
|
||||||
|
thinking := split[0]
|
||||||
|
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
|
||||||
|
|
||||||
|
remaining := split[1]
|
||||||
|
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
p.state = DeepSeekCollectingContent
|
||||||
|
|
||||||
|
if len(thinking) > 0 {
|
||||||
|
events = append(events, deepseekEventThinkingContent{content: thinking})
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(bufStr, deepseekThinkingCloseTag); overlapLen > 0 { // partial </think>
|
||||||
|
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||||
|
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
|
ambiguousStart := len(beforePartialTag) - trailingLen
|
||||||
|
|
||||||
|
unambiguous := bufStr[:ambiguousStart]
|
||||||
|
ambiguous := bufStr[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, deepseekEventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
} else { // otherwise its thinking content
|
||||||
|
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||||
|
ambiguousStart := len(bufStr) - whitespaceLen
|
||||||
|
|
||||||
|
unambiguous := bufStr[:ambiguousStart]
|
||||||
|
ambiguous := bufStr[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, deepseekEventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
case DeepSeekCollectingContent:
|
||||||
|
switch {
|
||||||
|
case strings.Contains(bufStr, deepseekToolCallsBeginTag): // content[<|tool▁calls▁begin|>] -> tool calls
|
||||||
|
split := strings.SplitN(bufStr, deepseekToolCallsBeginTag, 2)
|
||||||
|
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||||
|
remaining := split[1]
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
p.state = DeepSeekCollectingToolCalls
|
||||||
|
|
||||||
|
if len(contentBefore) > 0 {
|
||||||
|
events = append(events, deepseekEventContent{content: contentBefore})
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
case strings.Contains(bufStr, deepseekToolOutputBeginTag): // content[<|tool▁output▁begin|>] -> tool output
|
||||||
|
split := strings.SplitN(bufStr, deepseekToolOutputBeginTag, 2)
|
||||||
|
contentBefore := split[0] // Don't trim whitespace - preserve spaces
|
||||||
|
remaining := split[1]
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
p.state = DeepSeekCollectingToolOutput
|
||||||
|
|
||||||
|
if len(contentBefore) > 0 {
|
||||||
|
events = append(events, deepseekEventContent{content: contentBefore})
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
default: // otherwise its content
|
||||||
|
p.buffer.Reset()
|
||||||
|
if len(bufStr) > 0 {
|
||||||
|
events = append(events, deepseekEventContent{content: bufStr})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
case DeepSeekCollectingToolCalls:
|
||||||
|
if idx := strings.Index(bufStr, deepseekToolCallBeginTag); idx != -1 {
|
||||||
|
startIdx := idx + len(deepseekToolCallBeginTag)
|
||||||
|
if endIdx := strings.Index(bufStr[startIdx:], deepseekToolCallEndTag); endIdx != -1 {
|
||||||
|
toolCallContent := bufStr[startIdx : startIdx+endIdx]
|
||||||
|
|
||||||
|
if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
|
||||||
|
remaining := bufStr[startIdx+endIdx+len(deepseekToolCallEndTag):]
|
||||||
|
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
|
||||||
|
events = append(events, deepseekEventToolCall{toolCall: toolCall})
|
||||||
|
return events, true
|
||||||
|
} else {
|
||||||
|
slog.Warn("deepseek tool call parsing failed", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx := strings.Index(bufStr, deepseekToolCallsEndTag); idx != -1 {
|
||||||
|
remaining := bufStr[idx+len(deepseekToolCallsEndTag):]
|
||||||
|
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
p.state = DeepSeekCollectingContent
|
||||||
|
|
||||||
|
return events, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
case DeepSeekCollectingToolOutput:
|
||||||
|
if idx := strings.Index(bufStr, deepseekToolOutputEndTag); idx != -1 {
|
||||||
|
toolOutputContent := bufStr[:idx]
|
||||||
|
remaining := bufStr[idx+len(deepseekToolOutputEndTag):]
|
||||||
|
// Don't trim whitespace - preserve spaces after tool output tags
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
p.state = DeepSeekCollectingContent
|
||||||
|
|
||||||
|
if len(toolOutputContent) > 0 {
|
||||||
|
events = append(events, deepseekEventContent{content: toolOutputContent})
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeek3Parser) parseToolCallContent(content string) (api.ToolCall, error) {
|
||||||
|
// Expected format: tool_name<|tool▁sep|>{args}
|
||||||
|
parts := strings.SplitN(content, deepseekToolSepTag, 2)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return api.ToolCall{}, errors.New("invalid format")
|
||||||
|
}
|
||||||
|
|
||||||
|
toolName := strings.TrimSpace(parts[0])
|
||||||
|
argsJSON := strings.TrimSpace(parts[1])
|
||||||
|
|
||||||
|
var args api.ToolCallFunctionArguments
|
||||||
|
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||||
|
return api.ToolCall{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: toolName,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,721 @@
|
||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeepSeekParser(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedContent string
|
||||||
|
expectedThinking string
|
||||||
|
expectedCalls []api.ToolCall
|
||||||
|
hasThinking bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple_content",
|
||||||
|
input: "Hello, how are you?",
|
||||||
|
expectedContent: "Hello, how are you?",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_content",
|
||||||
|
input: "I need to think about this...</think>The answer is 42.",
|
||||||
|
expectedThinking: "I need to think about this...",
|
||||||
|
expectedContent: "The answer is 42.",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_thinking_simple",
|
||||||
|
input: "Just a regular response.",
|
||||||
|
expectedContent: "Just a regular response.",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_with_newlines",
|
||||||
|
input: "Let me think:\n- Point 1\n- Point 2</think>\n\nHere's my answer.",
|
||||||
|
expectedThinking: "Let me think:\n- Point 1\n- Point 2",
|
||||||
|
expectedContent: "Here's my answer.",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool_call_simple",
|
||||||
|
input: "I'll check the weather.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "I'll check the weather.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Paris",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple_tool_calls",
|
||||||
|
input: "Getting weather for both cities.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"London\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "Getting weather for both cities.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Paris",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "London",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool_output",
|
||||||
|
input: "Here's the weather: <|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|> Hope that helps!",
|
||||||
|
expectedContent: "Here's the weather: Temperature: 22°C, Sunny Hope that helps!",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex_tool_arguments",
|
||||||
|
input: "Processing data.<|tool▁calls▁begin|><|tool▁call▁begin|>process_data<|tool▁sep|>{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "Processing data.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "process_data",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"items": []interface{}{"item1", "item2"},
|
||||||
|
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_with_tool_call", // technically this can't happen, but the parser can handle it
|
||||||
|
input: "Let me check the weather...</think>I'll get that for you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedThinking: "Let me check the weather...",
|
||||||
|
expectedContent: "I'll get that for you.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Paris",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_content",
|
||||||
|
input: "",
|
||||||
|
expectedContent: "",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only_thinking",
|
||||||
|
input: "Just thinking content</think>",
|
||||||
|
expectedThinking: "Just thinking content",
|
||||||
|
expectedContent: "",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple_tool_outputs",
|
||||||
|
input: "Results: <|tool▁output▁begin|>Paris: 22°C<|tool▁output▁end|> and <|tool▁output▁begin|>London: 18°C<|tool▁output▁end|>",
|
||||||
|
expectedContent: "Results: Paris: 22°C and London: 18°C",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode_content",
|
||||||
|
input: "مرحبا بالعالم! 你好世界! 🌍",
|
||||||
|
expectedContent: "مرحبا بالعالم! 你好世界! 🌍",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "emoji_passthrough",
|
||||||
|
input: "Task completed ✅ 🎉",
|
||||||
|
expectedContent: "Task completed ✅ 🎉",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "emoji_after_tool_call",
|
||||||
|
input: "I'll help you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>完成 ✅",
|
||||||
|
expectedContent: "I'll help you.完成 ✅",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Tokyo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "newlines_and_whitespace",
|
||||||
|
input: "Line 1\n\nLine 3\t\tTabbed content",
|
||||||
|
expectedContent: "Line 1\n\nLine 3\t\tTabbed content",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_with_unicode",
|
||||||
|
input: "我在思考这个问题...</think>答案是42。",
|
||||||
|
expectedThinking: "我在思考这个问题...",
|
||||||
|
expectedContent: "答案是42。",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool_call_with_unicode_args",
|
||||||
|
input: "Searching for information.<|tool▁calls▁begin|><|tool▁call▁begin|>search<|tool▁sep|>{\"query\":\"北京天气\",\"language\":\"中文\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "Searching for information.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "search",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"query": "北京天气",
|
||||||
|
"language": "中文",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool_output_with_unicode",
|
||||||
|
input: "天气信息: <|tool▁output▁begin|>北京: 25°C, 晴天<|tool▁output▁end|> 希望对您有帮助!",
|
||||||
|
expectedContent: "天气信息: 北京: 25°C, 晴天 希望对您有帮助!",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed_content_with_special_chars",
|
||||||
|
input: "Price: $100 & tax @ 10% = $110 <|tool▁output▁begin|>Total: $110<|tool▁output▁end|> (final)",
|
||||||
|
expectedContent: "Price: $100 & tax @ 10% = $110 Total: $110 (final)",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool_call_with_special_chars",
|
||||||
|
input: "Processing data.<|tool▁calls▁begin|><|tool▁call▁begin|>execute_command<|tool▁sep|>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "Processing data.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "execute_command",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"command": "ls && echo \"done\"",
|
||||||
|
"path": "/home/user",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_with_special_chars",
|
||||||
|
input: "Let me calculate: 2+2=4 & 3*3=9...</think>The results are correct!",
|
||||||
|
expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...",
|
||||||
|
expectedContent: "The results are correct!",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_tool_call_args",
|
||||||
|
input: "Pinging server.<|tool▁calls▁begin|><|tool▁call▁begin|>ping<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "Pinging server.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "ping",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_tool_output",
|
||||||
|
input: "Checking status: <|tool▁output▁begin|><|tool▁output▁end|> No output received.",
|
||||||
|
expectedContent: "Checking status: No output received.",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking}
|
||||||
|
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add(tt.input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
|
||||||
|
t.Errorf("Content mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
|
||||||
|
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
|
||||||
|
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
chunks []string
|
||||||
|
expectedContent string
|
||||||
|
expectedThinking string
|
||||||
|
expectedCalls []api.ToolCall
|
||||||
|
hasThinking bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "streaming_simple_content",
|
||||||
|
chunks: []string{"Hello, ", "how are ", "you?"},
|
||||||
|
expectedContent: "Hello, how are you?",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_thinking",
|
||||||
|
chunks: []string{"I need to ", "think about this", "...</think>", "The answer is 42."},
|
||||||
|
expectedThinking: "I need to think about this...",
|
||||||
|
expectedContent: "The answer is 42.",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_tool_call",
|
||||||
|
chunks: []string{"I'll check weather.", "<|tool▁calls▁begin|>", "<|tool▁call▁begin|>get_weather", "<|tool▁sep|>{\"location\":\"Paris\"}", "<|tool▁call▁end|><|tool▁calls▁end|>"},
|
||||||
|
expectedContent: "I'll check weather.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Paris",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_thinking_with_partial_tag",
|
||||||
|
chunks: []string{"Thinking about this", "...</", "think>", "Done thinking."},
|
||||||
|
expectedThinking: "Thinking about this...",
|
||||||
|
expectedContent: "Done thinking.",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_tool_output",
|
||||||
|
chunks: []string{"Weather info: ", "<|tool▁output▁begin|>", "25°C, Sunny", "<|tool▁output▁end|>", " Enjoy!"},
|
||||||
|
expectedContent: "Weather info: 25°C, Sunny Enjoy!",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_with_split_tags",
|
||||||
|
chunks: []string{"Content before ", "<|tool▁calls▁begin|><|tool▁call▁begin|>test", "<|tool▁sep|>{}", "<|tool▁call▁end|><|tool▁calls▁end|>", " after"},
|
||||||
|
expectedContent: "Content before after",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "test",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_thinking_with_split_end_tag",
|
||||||
|
chunks: []string{"Thinking content", "</th", "ink>", "Regular content"},
|
||||||
|
expectedThinking: "Thinking content",
|
||||||
|
expectedContent: "Regular content",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_unicode_content",
|
||||||
|
chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"},
|
||||||
|
expectedContent: "مرحبا بالعالم! 你好世界!",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_multiple_tool_outputs",
|
||||||
|
chunks: []string{"Results: ", "<|tool▁output▁begin|>", "Paris: 22°C", "<|tool▁output▁end|>", " and ", "<|tool▁output▁begin|>", "London: 18°C", "<|tool▁output▁end|>"},
|
||||||
|
expectedContent: "Results: Paris: 22°C and London: 18°C",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming_tool_call_with_split_json",
|
||||||
|
chunks: []string{"Processing.", "<|tool▁calls▁begin|><|tool▁call▁begin|>calc<|tool▁sep|>{\"x\":", "42,\"y\":", "24}<|tool▁call▁end|><|tool▁calls▁end|>"},
|
||||||
|
expectedContent: "Processing.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "calc",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"x": float64(42),
|
||||||
|
"y": float64(24),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking}
|
||||||
|
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
|
||||||
|
|
||||||
|
var allContent, allThinking string
|
||||||
|
var allCalls []api.ToolCall
|
||||||
|
|
||||||
|
for i, chunk := range tt.chunks {
|
||||||
|
done := i == len(tt.chunks)-1
|
||||||
|
content, thinking, calls, err := parser.Add(chunk, done)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
allContent += content
|
||||||
|
allThinking += thinking
|
||||||
|
allCalls = append(allCalls, calls...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" {
|
||||||
|
t.Errorf("Content mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" {
|
||||||
|
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
|
||||||
|
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepSeekParser_HasThinkingSupport(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hasThinking bool
|
||||||
|
expectedSupport bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "thinking_enabled",
|
||||||
|
hasThinking: true,
|
||||||
|
expectedSupport: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_disabled",
|
||||||
|
hasThinking: false,
|
||||||
|
expectedSupport: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking}
|
||||||
|
if got := parser.HasThinkingSupport(); got != tt.expectedSupport {
|
||||||
|
t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepSeekParser_HasToolSupport(t *testing.T) {
|
||||||
|
parser := &DeepSeek3Parser{}
|
||||||
|
if !parser.HasToolSupport() {
|
||||||
|
t.Error("HasToolSupport() should return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepSeekParser_Init(t *testing.T) {
|
||||||
|
parser := &DeepSeek3Parser{hasThinkingSupport: true}
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "test_tool",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tools, returnedTools); diff != "" {
|
||||||
|
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test initial state is set to thinking when enabled
|
||||||
|
if parser.state != DeepSeekCollectingThinking {
|
||||||
|
t.Errorf("Expected initial state to be DeepSeekCollectingThinking, got %v", parser.state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
expected api.ToolCall
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid_tool_call",
|
||||||
|
content: "get_weather<|tool▁sep|>{\"location\":\"Paris\"}",
|
||||||
|
expected: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Paris",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex_arguments",
|
||||||
|
content: "process_data<|tool▁sep|>{\"items\":[\"a\",\"b\"],\"config\":{\"enabled\":true}}",
|
||||||
|
expected: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "process_data",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"items": []interface{}{"a", "b"},
|
||||||
|
"config": map[string]interface{}{"enabled": true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_arguments",
|
||||||
|
content: "ping<|tool▁sep|>{}",
|
||||||
|
expected: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "ping",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode_in_tool_name",
|
||||||
|
content: "获取天气<|tool▁sep|>{\"城市\":\"北京\"}",
|
||||||
|
expected: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "获取天气",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"城市": "北京",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special_chars_in_arguments",
|
||||||
|
content: "execute<|tool▁sep|>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}",
|
||||||
|
expected: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "execute",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"command": "ls && echo \"done\"",
|
||||||
|
"path": "/home/user",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "numeric_arguments",
|
||||||
|
content: "calculate<|tool▁sep|>{\"x\":3.14,\"y\":42,\"enabled\":true}",
|
||||||
|
expected: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "calculate",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"x": 3.14,
|
||||||
|
"y": float64(42),
|
||||||
|
"enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_format_no_separator",
|
||||||
|
content: "get_weather{\"location\":\"Paris\"}",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_json",
|
||||||
|
content: "get_weather<|tool▁sep|>{invalid json}",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_tool_name",
|
||||||
|
content: "<|tool▁sep|>{\"arg\":\"value\"}",
|
||||||
|
expectError: false, // This should work, just empty name
|
||||||
|
expected: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"arg": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_json_part",
|
||||||
|
content: "tool_name<|tool▁sep|>",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := &DeepSeek3Parser{}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := parser.parseToolCallContent(tt.content)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error but got none")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||||
|
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepSeekParser_EdgeCases(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedContent string
|
||||||
|
expectedThinking string
|
||||||
|
hasThinking bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nested_think_tags_in_thinking",
|
||||||
|
input: "Outer thinking <think>inner</think> content</think>Final content",
|
||||||
|
expectedThinking: "Outer thinking <think>inner",
|
||||||
|
expectedContent: "content</think>Final content",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple_think_close_tags",
|
||||||
|
input: "First thought</think>Second thought</think>Final content",
|
||||||
|
expectedThinking: "First thought",
|
||||||
|
expectedContent: "Second thought</think>Final content",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_thinking_content",
|
||||||
|
input: "</think>Just content",
|
||||||
|
expectedThinking: "",
|
||||||
|
expectedContent: "Just content",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_disabled_with_think_tags",
|
||||||
|
input: "Some content</think>More content",
|
||||||
|
expectedContent: "Some content</think>More content",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed_tool_call_missing_sep",
|
||||||
|
input: "Testing.<|tool▁calls▁begin|><|tool▁call▁begin|>bad_tool{\"arg\":\"value\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "Testing.",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed_tool_call_invalid_json",
|
||||||
|
input: "Testing.<|tool▁calls▁begin|><|tool▁call▁begin|>bad_tool<|tool▁sep|>{invalid json}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "Testing.",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial_tool_tag_at_end",
|
||||||
|
input: "Content with partial <|tool▁calls▁",
|
||||||
|
expectedContent: "Content with partial <|tool▁calls▁",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial_think_tag_at_end",
|
||||||
|
input: "Thinking content</th",
|
||||||
|
expectedContent: "Thinking content</th",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial_think_tag_at_end_with_thinking",
|
||||||
|
input: "Thinking content</th",
|
||||||
|
expectedThinking: "Thinking content",
|
||||||
|
expectedContent: "",
|
||||||
|
hasThinking: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace_only_content",
|
||||||
|
input: " \n\t ",
|
||||||
|
expectedContent: " \n\t ",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool_output_with_newlines",
|
||||||
|
input: "Output:\n<|tool▁output▁begin|>Line 1\nLine 2\nLine 3<|tool▁output▁end|>\nDone.",
|
||||||
|
expectedContent: "Output:\nLine 1\nLine 2\nLine 3\nDone.",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "consecutive_tool_calls",
|
||||||
|
input: "First.<|tool▁calls▁begin|><|tool▁call▁begin|>tool1<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>Second.<|tool▁calls▁begin|><|tool▁call▁begin|>tool2<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||||
|
expectedContent: "First.",
|
||||||
|
hasThinking: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking}
|
||||||
|
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
|
||||||
|
|
||||||
|
content, thinking, _, err := parser.Add(tt.input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
|
||||||
|
t.Errorf("Content mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
|
||||||
|
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Nemotron3NanoParserState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota
|
||||||
|
Nemotron3NanoSkipWhitespaceAfterThinking
|
||||||
|
Nemotron3NanoCollectingContent
|
||||||
|
Nemotron3NanoCollectingToolCalls
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
nemotronThinkClose = "</think>"
|
||||||
|
nemotronToolCallOpen = "<tool_call>"
|
||||||
|
nemotronToolCallClose = "</tool_call>"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Nemotron3NanoParser struct {
|
||||||
|
state Nemotron3NanoParserState
|
||||||
|
buffer strings.Builder
|
||||||
|
tools []api.Tool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
|
||||||
|
func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }
|
||||||
|
|
||||||
|
func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
|
p.tools = tools
|
||||||
|
|
||||||
|
// thinking is enabled if user requests it
|
||||||
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
|
|
||||||
|
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||||
|
|
||||||
|
if !thinkingEnabled {
|
||||||
|
p.state = Nemotron3NanoCollectingContent
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefill && lastMessage.Content != "" {
|
||||||
|
p.state = Nemotron3NanoCollectingContent
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
p.state = Nemotron3NanoCollectingThinking
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
type nemotronEvent interface {
|
||||||
|
isNemotronEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
type nemotronEventThinkingContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
type nemotronEventContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
type nemotronEventToolCall struct {
|
||||||
|
toolCall api.ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nemotronEventThinkingContent) isNemotronEvent() {}
|
||||||
|
func (nemotronEventContent) isNemotronEvent() {}
|
||||||
|
func (nemotronEventToolCall) isNemotronEvent() {}
|
||||||
|
|
||||||
|
func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.buffer.WriteString(s)
|
||||||
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
var contentSb strings.Builder
|
||||||
|
var thinkingSb strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
switch event := event.(type) {
|
||||||
|
case nemotronEventToolCall:
|
||||||
|
toolCalls = append(toolCalls, event.toolCall)
|
||||||
|
case nemotronEventThinkingContent:
|
||||||
|
thinkingSb.WriteString(event.content)
|
||||||
|
case nemotronEventContent:
|
||||||
|
contentSb.WriteString(event.content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent {
|
||||||
|
var all []nemotronEvent
|
||||||
|
|
||||||
|
keepLooping := true
|
||||||
|
for keepLooping {
|
||||||
|
var events []nemotronEvent
|
||||||
|
events, keepLooping = p.eat()
|
||||||
|
if len(events) > 0 {
|
||||||
|
all = append(all, events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitWithPartialCheck extracts unambiguous content before a potential partial tag
|
||||||
|
func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
|
||||||
|
if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
|
||||||
|
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||||
|
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
|
return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:]
|
||||||
|
}
|
||||||
|
wsLen := trailingWhitespaceLen(bufStr)
|
||||||
|
return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) {
|
||||||
|
bufStr := p.buffer.String()
|
||||||
|
if bufStr == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p.state {
|
||||||
|
case Nemotron3NanoCollectingThinking:
|
||||||
|
if strings.Contains(bufStr, nemotronThinkClose) {
|
||||||
|
split := strings.SplitN(bufStr, nemotronThinkClose, 2)
|
||||||
|
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||||
|
p.buffer.WriteString(remainder)
|
||||||
|
// Transition to whitespace-skipping state if buffer is empty,
|
||||||
|
// otherwise go directly to content collection
|
||||||
|
if remainder == "" {
|
||||||
|
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
|
||||||
|
} else {
|
||||||
|
p.state = Nemotron3NanoCollectingContent
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true
|
||||||
|
}
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambig)
|
||||||
|
if unambig != "" {
|
||||||
|
return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
|
||||||
|
// We only want to skip whitespace between thinking and content
|
||||||
|
case Nemotron3NanoSkipWhitespaceAfterThinking:
|
||||||
|
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(bufStr)
|
||||||
|
if bufStr == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
p.state = Nemotron3NanoCollectingContent
|
||||||
|
return nil, true
|
||||||
|
|
||||||
|
case Nemotron3NanoCollectingContent:
|
||||||
|
if strings.Contains(bufStr, nemotronToolCallOpen) {
|
||||||
|
split := strings.SplitN(bufStr, nemotronToolCallOpen, 2)
|
||||||
|
content := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(split[1])
|
||||||
|
p.state = Nemotron3NanoCollectingToolCalls
|
||||||
|
if content != "" {
|
||||||
|
return []nemotronEvent{nemotronEventContent{content: content}}, true
|
||||||
|
}
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambig)
|
||||||
|
if unambig != "" {
|
||||||
|
return []nemotronEvent{nemotronEventContent{content: unambig}}, false
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
|
||||||
|
case Nemotron3NanoCollectingToolCalls:
|
||||||
|
if strings.Contains(bufStr, nemotronToolCallClose) {
|
||||||
|
split := strings.SplitN(bufStr, nemotronToolCallClose, 2)
|
||||||
|
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
|
||||||
|
var events []nemotronEvent
|
||||||
|
if tc, err := p.parseToolCall(split[0]); err == nil {
|
||||||
|
events = append(events, nemotronEventToolCall{toolCall: tc})
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(remaining, nemotronToolCallOpen) {
|
||||||
|
p.state = Nemotron3NanoCollectingContent
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`)
|
||||||
|
nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) {
|
||||||
|
toolCall := api.ToolCall{}
|
||||||
|
|
||||||
|
// Extract function name
|
||||||
|
fnMatch := nemotronFunctionRegex.FindStringSubmatch(content)
|
||||||
|
if len(fnMatch) < 2 {
|
||||||
|
return toolCall, nil
|
||||||
|
}
|
||||||
|
toolCall.Function.Name = fnMatch[1]
|
||||||
|
|
||||||
|
// Extract parameters
|
||||||
|
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||||
|
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
||||||
|
for _, match := range paramMatches {
|
||||||
|
if len(match) >= 3 {
|
||||||
|
paramName := match[1]
|
||||||
|
paramValue := strings.TrimSpace(match[2])
|
||||||
|
|
||||||
|
// Try to parse as typed value based on tool definition
|
||||||
|
toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolCall, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any {
|
||||||
|
// Find the matching tool to get parameter type
|
||||||
|
var paramType api.PropertyType
|
||||||
|
for _, tool := range p.tools {
|
||||||
|
if prop, ok := tool.Function.Parameters.Properties[paramName]; ok {
|
||||||
|
paramType = prop.Type
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseValue(raw, paramType)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,574 @@
|
||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNemotron3NanoParser(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
thinkValue *api.ThinkValue
|
||||||
|
expectedContent string
|
||||||
|
expectedThinking string
|
||||||
|
expectedCalls []api.ToolCall
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple content - no thinking",
|
||||||
|
input: "Hello, how can I help you?",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Hello, how can I help you?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple content - thinking disabled",
|
||||||
|
input: "Hello, how can I help you?",
|
||||||
|
thinkValue: &api.ThinkValue{Value: false},
|
||||||
|
expectedContent: "Hello, how can I help you?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking then content",
|
||||||
|
input: "Let me think about this...</think>\nHere is my answer.",
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Let me think about this...",
|
||||||
|
expectedContent: "Here is my answer.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking with newlines",
|
||||||
|
input: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude</think>\nThe answer is 42.",
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude",
|
||||||
|
expectedContent: "The answer is 42.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple tool call",
|
||||||
|
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "Paris"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "content then tool call",
|
||||||
|
input: "Let me check the weather.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nNYC\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Let me check the weather.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "NYC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiple parameters",
|
||||||
|
input: "<tool_call>\n<function=book_flight>\n<parameter=from>\nSFO\n</parameter>\n<parameter=to>\nNYC\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "book_flight",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"from": "SFO",
|
||||||
|
"to": "NYC",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls",
|
||||||
|
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nSan Francisco\n</parameter>\n</function>\n</tool_call>\n" +
|
||||||
|
"<tool_call>\n<function=get_weather>\n<parameter=city>\nNew York\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "San Francisco"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "New York"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking then tool call",
|
||||||
|
input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "I should check the weather...",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "Paris"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking content then tool call",
|
||||||
|
input: "Let me think...</think>\nI'll check for you.\n<tool_call>\n<function=search>\n<parameter=query>\ntest\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Let me think...",
|
||||||
|
expectedContent: "I'll check for you.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "search",
|
||||||
|
Arguments: map[string]any{"query": "test"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiline parameter value",
|
||||||
|
input: "<tool_call>\n<function=create_note>\n<parameter=content>\nLine 1\nLine 2\nLine 3\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "create_note",
|
||||||
|
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty thinking block - immediate close",
|
||||||
|
input: "</think>\nHere is my answer.",
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "",
|
||||||
|
expectedContent: "Here is my answer.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking disabled but model outputs think close anyway",
|
||||||
|
input: "</think>\nSome content after spurious tag.",
|
||||||
|
thinkValue: &api.ThinkValue{Value: false},
|
||||||
|
expectedContent: "</think>\nSome content after spurious tag.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with no function name - returns empty tool call",
|
||||||
|
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "content with newlines preserved",
|
||||||
|
input: "Line 1\n\nLine 2\n\n\nLine 3",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Line 1\n\nLine 2\n\n\nLine 3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking with only whitespace after close tag",
|
||||||
|
input: "My thoughts...</think> \n\t\n Content here.",
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "My thoughts...",
|
||||||
|
expectedContent: "Content here.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode content",
|
||||||
|
input: "Hello 世界! 🌍 Ñoño",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Hello 世界! 🌍 Ñoño",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with numeric parameter",
|
||||||
|
input: "<tool_call>\n<function=set_temp>\n<parameter=value>\n42\n</parameter>\n</function>\n</tool_call>",
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "set_temp",
|
||||||
|
Arguments: map[string]any{"value": "42"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
p.Init(nil, nil, tt.thinkValue)
|
||||||
|
|
||||||
|
content, thinking, calls, err := p.Add(tt.input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain remaining content
|
||||||
|
finalContent, finalThinking, finalCalls, err := p.Add("", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on done: %v", err)
|
||||||
|
}
|
||||||
|
content += finalContent
|
||||||
|
thinking += finalThinking
|
||||||
|
calls = append(calls, finalCalls...)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
|
||||||
|
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||||
|
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||||
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
chunks []string
|
||||||
|
thinkValue *api.ThinkValue
|
||||||
|
expectedContent string
|
||||||
|
expectedThinking string
|
||||||
|
expectedCalls []api.ToolCall
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "streaming content character by character",
|
||||||
|
chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Hello, world!",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming content small tokens",
|
||||||
|
chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Hello, how can I help you today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming thinking then content - granular",
|
||||||
|
chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Let me think about this...",
|
||||||
|
expectedContent: "Here is my answer.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming thinking with newlines - granular",
|
||||||
|
chunks: []string{"Step", " 1", ":", " Ana", "lyze\n", "Step", " 2", ":", " Pro", "cess", "</", "thi", "nk>", "\n", "The", " ans", "wer."},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Step 1: Analyze\nStep 2: Process",
|
||||||
|
expectedContent: "The answer.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming tool call - highly granular",
|
||||||
|
chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "</", "param", "eter", ">", "\n", "</", "func", "tion", ">", "\n", "</", "tool", "_", "call", ">"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "Paris"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming content then tool call - granular",
|
||||||
|
chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "<function=", "get_weather", ">", "\n", "<parameter=", "city", ">", "\n", "NYC", "\n", "</parameter>", "\n", "</function>", "\n", "</tool_call>"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Let me check the weather.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "NYC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call tag split character by character",
|
||||||
|
chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"},
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "test",
|
||||||
|
Arguments: map[string]any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking close tag split character by character",
|
||||||
|
chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "I'm thinking...",
|
||||||
|
expectedContent: "Done!",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple whitespace after think tag - separate chunks",
|
||||||
|
chunks: []string{"Thinking...", "</think>", "\n", "\n", " ", "Content here."},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Thinking...",
|
||||||
|
expectedContent: "Content here.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiple parameters - streaming",
|
||||||
|
chunks: []string{"<tool_", "call>\n", "<function", "=book_", "flight>", "\n<para", "meter=", "from>\n", "SFO\n", "</param", "eter>", "\n<param", "eter=to", ">\nNYC", "\n</para", "meter>", "\n</func", "tion>\n", "</tool_", "call>"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "book_flight",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"from": "SFO",
|
||||||
|
"to": "NYC",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking then content then tool call - streaming",
|
||||||
|
chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Analyzing your request...",
|
||||||
|
expectedContent: "I'll check that for you.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "search",
|
||||||
|
Arguments: map[string]any{"query": "test query"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls - streaming",
|
||||||
|
chunks: []string{
|
||||||
|
"<tool_call>", "\n", "<function=", "get_weather>", "\n",
|
||||||
|
"<parameter=", "city>\n", "San Fran", "cisco\n", "</parameter>", "\n",
|
||||||
|
"</function>", "\n", "</tool_call>", "\n",
|
||||||
|
"<tool_", "call>\n", "<function", "=get_weather", ">\n",
|
||||||
|
"<param", "eter=city", ">\nNew", " York\n", "</parameter>\n",
|
||||||
|
"</function>\n", "</tool_call>",
|
||||||
|
},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "San Francisco"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "New York"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiline parameter - streaming",
|
||||||
|
chunks: []string{"<tool_call>\n", "<function=", "create_note>\n", "<parameter=", "content>\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n</parameter>\n", "</function>\n", "</tool_call>"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "create_note",
|
||||||
|
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty thinking block",
|
||||||
|
chunks: []string{"</think>", "\n", "Just content."},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "",
|
||||||
|
expectedContent: "Just content.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input chunks interspersed",
|
||||||
|
chunks: []string{"Hello", "", " ", "", "world", "", "!"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Hello world!",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call immediately after think close - no content",
|
||||||
|
chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expectedThinking: "Analyzing...",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "test",
|
||||||
|
Arguments: map[string]any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with empty parameter value",
|
||||||
|
chunks: []string{"<tool_call>\n<function=test>\n<parameter=name>\n", "\n</parameter>\n</function>\n</tool_call>"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "test",
|
||||||
|
Arguments: map[string]any{"name": ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial tool call tag at end - buffered",
|
||||||
|
chunks: []string{"Here's some content", "<tool"},
|
||||||
|
thinkValue: nil,
|
||||||
|
expectedContent: "Here's some content",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
p.Init(nil, nil, tt.thinkValue)
|
||||||
|
|
||||||
|
var allContent string
|
||||||
|
var allThinking string
|
||||||
|
var allCalls []api.ToolCall
|
||||||
|
|
||||||
|
for _, chunk := range tt.chunks {
|
||||||
|
content, thinking, calls, err := p.Add(chunk, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
allContent += content
|
||||||
|
allThinking += thinking
|
||||||
|
allCalls = append(allCalls, calls...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain
|
||||||
|
content, thinking, calls, err := p.Add("", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on done: %v", err)
|
||||||
|
}
|
||||||
|
allContent += content
|
||||||
|
allThinking += thinking
|
||||||
|
allCalls = append(allCalls, calls...)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||||
|
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" {
|
||||||
|
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||||
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNemotron3NanoParser_HasToolSupport(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
if !p.HasToolSupport() {
|
||||||
|
t.Error("expected HasToolSupport to return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNemotron3NanoParser_HasThinkingSupport(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
if !p.HasThinkingSupport() {
|
||||||
|
t.Error("expected HasThinkingSupport to return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNemotron3NanoParser_Init(t *testing.T) {
|
||||||
|
t.Run("starts in thinking state when enabled", func(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
p.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
if p.state != Nemotron3NanoCollectingThinking {
|
||||||
|
t.Errorf("expected state Nemotron3NanoCollectingThinking, got %v", p.state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("starts in content state when thinking disabled", func(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
p.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
if p.state != Nemotron3NanoCollectingContent {
|
||||||
|
t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("starts in content state when nil thinkValue", func(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
p.Init(nil, nil, nil)
|
||||||
|
if p.state != Nemotron3NanoCollectingContent {
|
||||||
|
t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("starts in content state with assistant prefill", func(t *testing.T) {
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
prefill := &api.Message{Role: "assistant", Content: "Starting..."}
|
||||||
|
p.Init(nil, prefill, &api.ThinkValue{Value: true})
|
||||||
|
if p.state != Nemotron3NanoCollectingContent {
|
||||||
|
t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &Nemotron3NanoParser{}
|
||||||
|
returnedTools := p.Init(tools, nil, nil)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(returnedTools, tools); diff != "" {
|
||||||
|
t.Errorf("tools mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse a tool call
|
||||||
|
input := "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>"
|
||||||
|
_, _, calls, err := p.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedCalls := []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "Paris"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(calls, expectedCalls); diff != "" {
|
||||||
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -58,10 +58,14 @@ func ParserForName(name string) Parser {
|
||||||
return harmony.NewHarmonyMessageHandler()
|
return harmony.NewHarmonyMessageHandler()
|
||||||
case "cogito":
|
case "cogito":
|
||||||
return &CogitoParser{}
|
return &CogitoParser{}
|
||||||
|
case "deepseek3":
|
||||||
|
return &DeepSeek3Parser{hasThinkingSupport: true}
|
||||||
case "olmo3":
|
case "olmo3":
|
||||||
return &Olmo3Parser{}
|
return &Olmo3Parser{}
|
||||||
case "olmo3-think":
|
case "olmo3-think":
|
||||||
return &Olmo3ThinkParser{}
|
return &Olmo3ThinkParser{}
|
||||||
|
case "nemotron-3-nano":
|
||||||
|
return &Nemotron3NanoParser{}
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,157 @@
|
||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeepSeek3Variant int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Deepseek31 DeepSeek3Variant = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeepSeek3Renderer struct {
|
||||||
|
IsThinking bool
|
||||||
|
Variant DeepSeek3Variant
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DeepSeek3Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// thinking is enabled: model must support it AND user must request it
|
||||||
|
thinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
|
||||||
|
|
||||||
|
// extract system messages first
|
||||||
|
var systemPrompt strings.Builder
|
||||||
|
isFirstSystemPrompt := true
|
||||||
|
|
||||||
|
for _, message := range messages {
|
||||||
|
if message.Role == "system" {
|
||||||
|
if isFirstSystemPrompt {
|
||||||
|
systemPrompt.WriteString(message.Content)
|
||||||
|
isFirstSystemPrompt = false
|
||||||
|
} else {
|
||||||
|
systemPrompt.WriteString("\n\n" + message.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("<|begin▁of▁sentence|>")
|
||||||
|
sb.WriteString(systemPrompt.String())
|
||||||
|
|
||||||
|
// tool definitions
|
||||||
|
if len(tools) > 0 {
|
||||||
|
sb.WriteString("\n\n## Tools\nYou have access to the following tools:\n")
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
sb.WriteString("\n### " + tool.Function.Name)
|
||||||
|
sb.WriteString("\nDescription: " + tool.Function.Description)
|
||||||
|
|
||||||
|
// parameters as JSON
|
||||||
|
parametersJSON, err := json.Marshal(tool.Function.Parameters)
|
||||||
|
if err == nil {
|
||||||
|
sb.WriteString("\n\nParameters: " + string(parametersJSON) + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// usage instructions
|
||||||
|
sb.WriteString("\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n")
|
||||||
|
sb.WriteString("<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\n")
|
||||||
|
sb.WriteString("Where:\n\n")
|
||||||
|
sb.WriteString("- `tool_call_name` must be an exact match to one of the available tools\n")
|
||||||
|
sb.WriteString("- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n")
|
||||||
|
sb.WriteString("- For multiple tool calls, chain them directly without separators or spaces\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// state tracking
|
||||||
|
isTool := false
|
||||||
|
isLastUser := false
|
||||||
|
|
||||||
|
// Find the index of the last user message to determine which assistant message is "current"
|
||||||
|
lastUserIndex := -1
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if messages[i].Role == "user" {
|
||||||
|
lastUserIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, message := range messages {
|
||||||
|
switch message.Role {
|
||||||
|
case "user":
|
||||||
|
isTool = false
|
||||||
|
isLastUser = true
|
||||||
|
sb.WriteString("<|User|>" + message.Content)
|
||||||
|
|
||||||
|
case "assistant":
|
||||||
|
if len(message.ToolCalls) > 0 {
|
||||||
|
if isLastUser {
|
||||||
|
sb.WriteString("<|Assistant|></think>")
|
||||||
|
}
|
||||||
|
isLastUser = false
|
||||||
|
isTool = false
|
||||||
|
|
||||||
|
if message.Content != "" {
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("<|tool▁calls▁begin|>")
|
||||||
|
for _, toolCall := range message.ToolCalls {
|
||||||
|
sb.WriteString("<|tool▁call▁begin|>" + toolCall.Function.Name + "<|tool▁sep|>")
|
||||||
|
|
||||||
|
argsJSON, _ := json.Marshal(toolCall.Function.Arguments)
|
||||||
|
sb.WriteString(string(argsJSON))
|
||||||
|
sb.WriteString("<|tool▁call▁end|>")
|
||||||
|
}
|
||||||
|
sb.WriteString("<|tool▁calls▁end|><|end▁of▁sentence|>")
|
||||||
|
} else {
|
||||||
|
if isLastUser {
|
||||||
|
sb.WriteString("<|Assistant|>")
|
||||||
|
hasThinking := message.Thinking != ""
|
||||||
|
|
||||||
|
// only use <think> for the current turn (after last user message)
|
||||||
|
isCurrentTurn := i > lastUserIndex
|
||||||
|
if hasThinking && thinking && isCurrentTurn {
|
||||||
|
sb.WriteString("<think>")
|
||||||
|
} else {
|
||||||
|
sb.WriteString("</think>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isLastUser = false
|
||||||
|
|
||||||
|
content := message.Content
|
||||||
|
if isTool {
|
||||||
|
sb.WriteString(content + "<|end▁of▁sentence|>")
|
||||||
|
isTool = false
|
||||||
|
} else {
|
||||||
|
if strings.Contains(content, "</think>") {
|
||||||
|
parts := strings.SplitN(content, "</think>", 2)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
content = parts[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sb.WriteString(content + "<|end▁of▁sentence|>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "tool":
|
||||||
|
isLastUser = false
|
||||||
|
isTool = true
|
||||||
|
sb.WriteString("<|tool▁output▁begin|>" + message.Content + "<|tool▁output▁end|>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLastUser && !isTool {
|
||||||
|
sb.WriteString("<|Assistant|>")
|
||||||
|
if thinking {
|
||||||
|
sb.WriteString("<think>")
|
||||||
|
} else {
|
||||||
|
sb.WriteString("</think>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,220 @@
|
||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Nemotron3NanoRenderer struct{}
|
||||||
|
|
||||||
|
func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// thinking is enabled if user requests it
|
||||||
|
enableThinking := thinkValue != nil && thinkValue.Bool()
|
||||||
|
|
||||||
|
// Extract system message if present
|
||||||
|
var systemMessage string
|
||||||
|
var loopMessages []api.Message
|
||||||
|
if len(messages) > 0 && messages[0].Role == "system" {
|
||||||
|
systemMessage = messages[0].Content
|
||||||
|
loopMessages = messages[1:]
|
||||||
|
} else {
|
||||||
|
loopMessages = messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find last user message index for thinking truncation
|
||||||
|
lastUserIdx := -1
|
||||||
|
for i, msg := range loopMessages {
|
||||||
|
if msg.Role == "user" {
|
||||||
|
lastUserIdx = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("<|im_start|>system\n")
|
||||||
|
if systemMessage != "" {
|
||||||
|
sb.WriteString(systemMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
if systemMessage != "" {
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
sb.WriteString(r.renderTools(tools))
|
||||||
|
}
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
|
||||||
|
for i, message := range loopMessages {
|
||||||
|
switch message.Role {
|
||||||
|
case "assistant":
|
||||||
|
// Build content with thinking tags
|
||||||
|
content := r.buildContent(message)
|
||||||
|
shouldTruncate := i < lastUserIdx
|
||||||
|
|
||||||
|
if len(message.ToolCalls) > 0 {
|
||||||
|
sb.WriteString("<|im_start|>assistant\n")
|
||||||
|
sb.WriteString(r.formatContent(content, shouldTruncate, true))
|
||||||
|
r.writeToolCalls(&sb, message.ToolCalls)
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
} else {
|
||||||
|
formatted := r.formatContent(content, shouldTruncate, false)
|
||||||
|
sb.WriteString("<|im_start|>assistant\n" + formatted + "<|im_end|>\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
case "user", "system":
|
||||||
|
sb.WriteString("<|im_start|>" + message.Role + "\n")
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
|
||||||
|
case "tool":
|
||||||
|
// Check if previous message was also a tool message
|
||||||
|
prevWasTool := i > 0 && loopMessages[i-1].Role == "tool"
|
||||||
|
nextIsTool := i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool"
|
||||||
|
|
||||||
|
if !prevWasTool {
|
||||||
|
sb.WriteString("<|im_start|>user\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("<tool_response>\n")
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
sb.WriteString("\n</tool_response>\n")
|
||||||
|
|
||||||
|
if !nextIsTool {
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
sb.WriteString("<|im_start|>" + message.Role + "\n" + message.Content + "<|im_end|>\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add generation prompt
|
||||||
|
if enableThinking {
|
||||||
|
sb.WriteString("<|im_start|>assistant\n<think>\n")
|
||||||
|
} else {
|
||||||
|
sb.WriteString("<|im_start|>assistant\n<think></think>")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n<tools>")
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
fn := tool.Function
|
||||||
|
sb.WriteString("\n<function>\n<name>" + fn.Name + "</name>")
|
||||||
|
|
||||||
|
if fn.Description != "" {
|
||||||
|
sb.WriteString("\n<description>" + strings.TrimSpace(fn.Description) + "</description>")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("\n<parameters>")
|
||||||
|
if fn.Parameters.Properties != nil {
|
||||||
|
for paramName, paramFields := range fn.Parameters.Properties {
|
||||||
|
sb.WriteString("\n<parameter>")
|
||||||
|
sb.WriteString("\n<name>" + paramName + "</name>")
|
||||||
|
|
||||||
|
if len(paramFields.Type) > 0 {
|
||||||
|
sb.WriteString("\n<type>" + strings.Join(paramFields.Type, ", ") + "</type>")
|
||||||
|
}
|
||||||
|
|
||||||
|
if paramFields.Description != "" {
|
||||||
|
sb.WriteString("\n<description>" + strings.TrimSpace(paramFields.Description) + "</description>")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(paramFields.Enum) > 0 {
|
||||||
|
enumJSON, _ := json.Marshal(paramFields.Enum)
|
||||||
|
sb.WriteString("\n<enum>" + string(enumJSON) + "</enum>")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("\n</parameter>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fn.Parameters.Required) > 0 {
|
||||||
|
reqJSON, _ := json.Marshal(fn.Parameters.Required)
|
||||||
|
sb.WriteString("\n<required>" + string(reqJSON) + "</required>")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("\n</parameters>")
|
||||||
|
sb.WriteString("\n</function>")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("\n</tools>")
|
||||||
|
|
||||||
|
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Nemotron3NanoRenderer) buildContent(message api.Message) string {
|
||||||
|
// The parser always extracts thinking into the Thinking field,
|
||||||
|
// so Content will never have <think> tags embedded
|
||||||
|
if message.Thinking != "" {
|
||||||
|
return "<think>\n" + message.Thinking + "\n</think>\n" + message.Content
|
||||||
|
}
|
||||||
|
return "<think></think>" + message.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, addNewline bool) string {
|
||||||
|
if content == "" {
|
||||||
|
return "<think></think>"
|
||||||
|
}
|
||||||
|
|
||||||
|
if !truncate {
|
||||||
|
if addNewline {
|
||||||
|
return strings.TrimSpace(content) + "\n"
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate thinking - keep only content after </think>
|
||||||
|
c := content
|
||||||
|
if strings.Contains(c, "</think>") {
|
||||||
|
parts := strings.Split(c, "</think>")
|
||||||
|
c = parts[len(parts)-1]
|
||||||
|
} else if strings.Contains(c, "<think>") {
|
||||||
|
parts := strings.Split(c, "<think>")
|
||||||
|
c = parts[0]
|
||||||
|
}
|
||||||
|
c = "<think></think>" + strings.TrimSpace(c)
|
||||||
|
|
||||||
|
if addNewline && len(c) > len("<think></think>") {
|
||||||
|
return c + "\n"
|
||||||
|
}
|
||||||
|
if c == "<think></think>" {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
|
||||||
|
for _, tc := range toolCalls {
|
||||||
|
sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n")
|
||||||
|
for name, value := range tc.Function.Arguments {
|
||||||
|
sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("</function>\n</tool_call>\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Nemotron3NanoRenderer) formatArgValue(value any) string {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case map[string]any, []any:
|
||||||
|
jsonBytes, _ := json.Marshal(v)
|
||||||
|
return string(jsonBytes)
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,569 @@
|
||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
msgs []api.Message
|
||||||
|
tools []api.Tool
|
||||||
|
thinkValue *api.ThinkValue
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic user message - thinking mode",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "basic user message - no thinking",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
thinkValue: nil,
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with system message",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-turn conversation",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
{Role: "assistant", Content: "Hello! How can I help?"},
|
||||||
|
{Role: "user", Content: "Tell me a joke"},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHi<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>Hello! How can I help?<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nTell me a joke<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with tools",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "What's the weather in Paris?"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"city"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>get_weather</name>\n" +
|
||||||
|
"<description>Get the current weather</description>\n" +
|
||||||
|
"<parameters>\n" +
|
||||||
|
"<parameter>\n<name>city</name>\n<type>string</type>\n<description>The city name</description>\n</parameter>\n" +
|
||||||
|
"<required>[\"city\"]</required>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with response",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "What's the weather in Paris?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "Paris"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: "Sunny, 72F"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"city"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>get_weather</name>\n" +
|
||||||
|
"<description>Get the current weather</description>\n" +
|
||||||
|
"<parameters>\n" +
|
||||||
|
"<parameter>\n<name>city</name>\n<type>string</type>\n<description>The city name</description>\n</parameter>\n" +
|
||||||
|
"<required>[\"city\"]</required>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>\n" +
|
||||||
|
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\nSunny, 72F\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "assistant with content and tool call",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Let me check that for you.",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "Paris"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: "Sunny"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>get_weather</name>\n" +
|
||||||
|
"<parameters>\n" +
|
||||||
|
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nWhat's the weather?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>Let me check that for you.\n" +
|
||||||
|
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\nSunny\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking in history is truncated",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
{Role: "assistant", Content: "Hello!", Thinking: "Let me think about this..."},
|
||||||
|
{Role: "user", Content: "How are you?"},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHi<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>Hello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHow are you?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parallel tool calls",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Weather in Paris and London?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "Paris"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"city": "London"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: "Sunny"},
|
||||||
|
{Role: "tool", Content: "Rainy"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>get_weather</name>\n" +
|
||||||
|
"<parameters>\n" +
|
||||||
|
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nWeather in Paris and London?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>\n" +
|
||||||
|
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n" +
|
||||||
|
"<tool_call>\n<function=get_weather>\n<parameter=city>\nLondon\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\nSunny\n</tool_response>\n<tool_response>\nRainy\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking disabled when user doesn't request it",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
thinkValue: nil,
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex message history with thinking, tools, tool calls, tool results and content",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
|
||||||
|
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
|
||||||
|
}},
|
||||||
|
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
|
||||||
|
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
|
||||||
|
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}},
|
||||||
|
}},
|
||||||
|
{Role: "tool", Content: "4", ToolCallID: "call3"},
|
||||||
|
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "calculate",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"expression": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>get_weather</name>\n" +
|
||||||
|
"<parameters>\n" +
|
||||||
|
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
|
||||||
|
"</parameters>\n</function>\n" +
|
||||||
|
"<function>\n<name>calculate</name>\n" +
|
||||||
|
"<parameters>\n" +
|
||||||
|
"<parameter>\n<name>expression</name>\n<type>string</type>\n</parameter>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nWhat's the weather in Paris and London? Also, what's 2+2?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
"<think>\nI need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.\n</think>\n" +
|
||||||
|
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n" +
|
||||||
|
"<tool_call>\n<function=get_weather>\n<parameter=city>\nLondon\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\nSunny, 22°C\n</tool_response>\n<tool_response>\nRainy, 15°C\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
"<think>\nNow I have the weather data. Let me calculate 2+2.\n</think>\n" +
|
||||||
|
"<tool_call>\n<function=calculate>\n<parameter=expression>\n2+2\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\n4\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
"<think>\nPerfect! I have all the information needed to provide a complete answer.\n</think>\n" +
|
||||||
|
"Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty messages list",
|
||||||
|
msgs: []api.Message{},
|
||||||
|
thinkValue: nil,
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n<|im_start|>assistant\n<think></think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool result with JSON content",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Get user info"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_user",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>get_user</name>\n<parameters>\n" +
|
||||||
|
"<parameter>\n<name>id</name>\n<type>string</type>\n</parameter>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nGet user info<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>\n" +
|
||||||
|
"<tool_call>\n<function=get_user>\n<parameter=id>\n123\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\n{\"name\": \"John\", \"age\": 30, \"active\": true}\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "assistant message with only thinking no content",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Think about this"},
|
||||||
|
{Role: "assistant", Thinking: "Deep thoughts here...", Content: ""},
|
||||||
|
{Role: "user", Content: "What did you think?"},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nThink about this<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nWhat did you think?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with complex nested argument",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Create data"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{
|
||||||
|
Name: "create",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"data": map[string]any{"nested": "value", "count": 42},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: "Created"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "create",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>create</name>\n<parameters>\n" +
|
||||||
|
"<parameter>\n<name>data</name>\n<type>object</type>\n</parameter>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nCreate data<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>\n" +
|
||||||
|
"<tool_call>\n<function=create>\n<parameter=data>\n{\"count\":42,\"nested\":\"value\"}\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\nCreated\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "content explaining the format itself",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "How do I format a tool call?"},
|
||||||
|
{Role: "assistant", Content: "To call a tool, use <tool_call> tags with <function=name> inside."},
|
||||||
|
{Role: "user", Content: "Thanks!"},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nHow do I format a tool call?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>To call a tool, use <tool_call> tags with <function=name> inside.<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nThanks!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode in content and tool args",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Translate 你好"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: "Hello"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "translate",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"text": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
thinkValue: &api.ThinkValue{Value: true},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||||
|
"<function>\n<name>translate</name>\n<parameters>\n" +
|
||||||
|
"<parameter>\n<name>text</name>\n<type>string</type>\n</parameter>\n" +
|
||||||
|
"</parameters>\n</function>\n</tools>\n\n" +
|
||||||
|
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||||
|
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||||
|
"- Required parameters MUST be specified\n" +
|
||||||
|
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||||
|
"</IMPORTANT><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\nTranslate 你好<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think></think>\n" +
|
||||||
|
"<tool_call>\n<function=translate>\n<parameter=text>\n你好\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n<tool_response>\nHello\n</tool_response>\n<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n<think>\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
renderer := &Nemotron3NanoRenderer{}
|
||||||
|
rendered, err := renderer.Render(tt.msgs, tt.tools, tt.thinkValue)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -10,12 +10,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
|
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
|
||||||
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
|
olmo31DefaultSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai. "
|
||||||
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
|
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
|
||||||
|
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
|
||||||
)
|
)
|
||||||
|
|
||||||
type Olmo3Renderer struct{}
|
type Olmo3Renderer struct {
|
||||||
|
UseExtendedSystemMessage bool
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
@ -51,7 +54,11 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
||||||
} else {
|
} else {
|
||||||
// Default system message - single newline after "system"
|
// Default system message - single newline after "system"
|
||||||
sb.WriteString("<|im_start|>system\n")
|
sb.WriteString("<|im_start|>system\n")
|
||||||
sb.WriteString(olmo3DefaultSystemMessage)
|
if r.UseExtendedSystemMessage {
|
||||||
|
sb.WriteString(olmo31DefaultSystemMessage)
|
||||||
|
} else {
|
||||||
|
sb.WriteString(olmo3DefaultSystemMessage)
|
||||||
|
}
|
||||||
|
|
||||||
if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
functionsJSON, err := marshalWithSpaces(tools)
|
functionsJSON, err := marshalWithSpaces(tools)
|
||||||
|
|
@ -140,7 +147,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
||||||
}
|
}
|
||||||
|
|
||||||
if needsGenerationPrompt {
|
if needsGenerationPrompt {
|
||||||
sb.WriteString("<|im_start|>assistant\n\n")
|
sb.WriteString("<|im_start|>assistant\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"Hello!<|im_end|>\n" +
|
"Hello!<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with system message no tools",
|
name: "with system message no tools",
|
||||||
|
|
@ -36,7 +36,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
"You are a helpful assistant.<|im_end|>\n" +
|
"You are a helpful assistant.<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"Hello!<|im_end|>\n" +
|
"Hello!<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with system message and tools",
|
name: "with system message and tools",
|
||||||
|
|
@ -64,7 +64,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"What is the weather?<|im_end|>\n" +
|
"What is the weather?<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "default system with tools - includes function instruction",
|
name: "default system with tools - includes function instruction",
|
||||||
|
|
@ -93,7 +93,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"What is the weather?<|im_end|>\n" +
|
"What is the weather?<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "assistant with tool calls - function call syntax",
|
name: "assistant with tool calls - function call syntax",
|
||||||
|
|
@ -141,7 +141,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
|
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
|
||||||
"<|im_start|>environment\n" +
|
"<|im_start|>environment\n" +
|
||||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multi-turn conversation",
|
name: "multi-turn conversation",
|
||||||
|
|
@ -159,7 +159,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
"Hi there!<|im_end|>\n" +
|
"Hi there!<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"How are you?<|im_end|>\n" +
|
"How are you?<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "parallel tool calls - newline separated",
|
name: "parallel tool calls - newline separated",
|
||||||
|
|
@ -214,7 +214,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||||
"<|im_start|>environment\n" +
|
"<|im_start|>environment\n" +
|
||||||
`{"temperature": 55}<|im_end|>` + "\n" +
|
`{"temperature": 55}<|im_end|>` + "\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "tool call with multiple arguments",
|
name: "tool call with multiple arguments",
|
||||||
|
|
@ -259,7 +259,7 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
"Book a flight<|im_end|>\n" +
|
"Book a flight<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
|
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
|
||||||
"<|im_start|>assistant\n\n",
|
"<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "assistant prefill - no generation prompt",
|
name: "assistant prefill - no generation prompt",
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,31 @@
|
||||||
package renderers
|
package renderers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Olmo3ThinkVariant int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai."
|
// Olmo3Think32B is for allenai/Olmo-3-32B-Think
|
||||||
olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions."
|
Olmo3Think32B Olmo3ThinkVariant = iota
|
||||||
|
// Olmo31Think is for allenai/Olmo-3-7B-Think and allenai/Olmo-3.1-32B-Think (includes model info)
|
||||||
|
Olmo31Think
|
||||||
)
|
)
|
||||||
|
|
||||||
type Olmo3ThinkRenderer struct{}
|
const (
|
||||||
|
olmo3ThinkFunctionsSuffix = " You do not currently have access to any functions. <functions></functions>"
|
||||||
|
olmo3Think32BSystemMessage = "You are a helpful AI assistant."
|
||||||
|
olmo31ThinkSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai."
|
||||||
|
)
|
||||||
|
|
||||||
type olmo3ThinkToolCall struct {
|
type Olmo3ThinkRenderer struct {
|
||||||
ID string `json:"id,omitempty"`
|
Variant Olmo3ThinkVariant
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Function olmo3ThinkToolCallFunc `json:"function"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type olmo3ThinkToolCallFunc struct {
|
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, _ []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments string `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
var systemMessage *api.Message
|
var systemMessage *api.Message
|
||||||
|
|
@ -37,34 +37,31 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Skip tool messages - Think models don't support tools
|
||||||
|
if message.Role == "tool" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
filteredMessages = append(filteredMessages, message)
|
filteredMessages = append(filteredMessages, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
systemContent := olmo3ThinkDefaultSystemMessage
|
|
||||||
if systemMessage != nil {
|
|
||||||
systemContent = systemMessage.Content
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString("<|im_start|>system\n")
|
sb.WriteString("<|im_start|>system\n")
|
||||||
sb.WriteString(systemContent)
|
|
||||||
|
|
||||||
if len(tools) > 0 {
|
if systemMessage != nil {
|
||||||
functionsJSON, err := marshalWithSpaces(tools)
|
sb.WriteString(systemMessage.Content)
|
||||||
if err != nil {
|
sb.WriteString(olmo3ThinkFunctionsSuffix)
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
sb.WriteString(" <functions>")
|
|
||||||
sb.WriteString(string(functionsJSON))
|
|
||||||
sb.WriteString("</functions>")
|
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString(olmo3ThinkNoFunctionsMessage)
|
// Default system message varies by variant
|
||||||
sb.WriteString(" <functions></functions>")
|
switch r.Variant {
|
||||||
|
case Olmo3Think32B:
|
||||||
|
sb.WriteString(olmo3Think32BSystemMessage)
|
||||||
|
default: // Olmo3Think7B, Olmo31Think use same template - diverges from HF but confirmed difference from team
|
||||||
|
sb.WriteString(olmo31ThinkSystemMessage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString("<|im_end|>\n")
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
|
||||||
for i, message := range filteredMessages {
|
for _, message := range filteredMessages {
|
||||||
lastMessage := i == len(filteredMessages)-1
|
|
||||||
|
|
||||||
switch message.Role {
|
switch message.Role {
|
||||||
case "user":
|
case "user":
|
||||||
sb.WriteString("<|im_start|>user\n")
|
sb.WriteString("<|im_start|>user\n")
|
||||||
|
|
@ -73,58 +70,15 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||||
|
|
||||||
case "assistant":
|
case "assistant":
|
||||||
sb.WriteString("<|im_start|>assistant\n")
|
sb.WriteString("<|im_start|>assistant\n")
|
||||||
|
|
||||||
if message.Content != "" {
|
if message.Content != "" {
|
||||||
sb.WriteString(message.Content)
|
sb.WriteString(message.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(message.ToolCalls) > 0 {
|
|
||||||
toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls))
|
|
||||||
for j, tc := range message.ToolCalls {
|
|
||||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
toolCalls[j] = olmo3ThinkToolCall{
|
|
||||||
ID: tc.ID,
|
|
||||||
Type: "function",
|
|
||||||
Function: olmo3ThinkToolCallFunc{
|
|
||||||
Name: tc.Function.Name,
|
|
||||||
Arguments: string(argsJSON),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
toolCallsJSON, err := marshalWithSpaces(toolCalls)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
sb.WriteString("<function_calls>")
|
|
||||||
sb.WriteString(string(toolCallsJSON))
|
|
||||||
sb.WriteString("</function_calls>")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !lastMessage {
|
|
||||||
sb.WriteString("<|im_end|>\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
case "tool":
|
|
||||||
sb.WriteString("<|im_start|>environment\n")
|
|
||||||
sb.WriteString(message.Content)
|
|
||||||
sb.WriteString("<|im_end|>\n")
|
sb.WriteString("<|im_end|>\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
needsGenerationPrompt := true
|
// Always add generation prompt with <think> tag for thinking models
|
||||||
if len(filteredMessages) > 0 {
|
sb.WriteString("<|im_start|>assistant\n<think>")
|
||||||
lastMsg := filteredMessages[len(filteredMessages)-1]
|
|
||||||
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
|
|
||||||
needsGenerationPrompt = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if needsGenerationPrompt {
|
|
||||||
sb.WriteString("<|im_start|>assistant\n<think>")
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,24 +11,27 @@ import (
|
||||||
func TestOlmo3ThinkRenderer(t *testing.T) {
|
func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
variant Olmo3ThinkVariant
|
||||||
msgs []api.Message
|
msgs []api.Message
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "basic without system - adds default system",
|
name: "7b_basic_without_system",
|
||||||
|
variant: Olmo31Think,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "Hello!"},
|
{Role: "user", Content: "Hello!"},
|
||||||
},
|
},
|
||||||
expected: "<|im_start|>system\n" +
|
expected: "<|im_start|>system\n" +
|
||||||
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"Hello!<|im_end|>\n" +
|
"Hello!<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
"<think>",
|
"<think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with system message no tools",
|
name: "7b_with_custom_system",
|
||||||
|
variant: Olmo31Think,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a helpful assistant."},
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
{Role: "user", Content: "Hello!"},
|
{Role: "user", Content: "Hello!"},
|
||||||
|
|
@ -41,9 +44,9 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||||
"<think>",
|
"<think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with system message and tools",
|
name: "7b_tools_ignored",
|
||||||
|
variant: Olmo31Think,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a helpful assistant."},
|
|
||||||
{Role: "user", Content: "What is the weather?"},
|
{Role: "user", Content: "What is the weather?"},
|
||||||
},
|
},
|
||||||
tools: []api.Tool{
|
tools: []api.Tool{
|
||||||
|
|
@ -52,27 +55,20 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Description: "Get the current weather",
|
Description: "Get the current weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Required: []string{"location"},
|
|
||||||
Properties: map[string]api.ToolProperty{
|
|
||||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "<|im_start|>system\n" +
|
expected: "<|im_start|>system\n" +
|
||||||
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"What is the weather?<|im_end|>\n" +
|
"What is the weather?<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
"<think>",
|
"<think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "assistant with tool calls",
|
name: "7b_tool_calls_and_tool_messages_ignored",
|
||||||
|
variant: Olmo31Think,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a helpful assistant."},
|
|
||||||
{Role: "user", Content: "What is the weather in SF?"},
|
{Role: "user", Content: "What is the weather in SF?"},
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
|
|
@ -81,53 +77,33 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
ID: "call_1",
|
ID: "call_1",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: map[string]any{"location": "San Francisco"},
|
||||||
"location": "San Francisco",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
|
||||||
},
|
|
||||||
tools: []api.Tool{
|
|
||||||
{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Description: "Get the current weather",
|
|
||||||
Parameters: api.ToolFunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Required: []string{"location"},
|
|
||||||
Properties: map[string]api.ToolProperty{
|
|
||||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{Role: "tool", Content: `{"temperature": 68}`},
|
||||||
},
|
},
|
||||||
expected: "<|im_start|>system\n" +
|
expected: "<|im_start|>system\n" +
|
||||||
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"What is the weather in SF?<|im_end|>\n" +
|
"What is the weather in SF?<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
`Let me check the weather.<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}]</function_calls><|im_end|>` + "\n" +
|
"Let me check the weather.<|im_end|>\n" +
|
||||||
"<|im_start|>environment\n" +
|
|
||||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
"<think>",
|
"<think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multi-turn conversation",
|
name: "7b_multi_turn_conversation",
|
||||||
|
variant: Olmo31Think,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a helpful assistant."},
|
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: "Hello"},
|
||||||
{Role: "assistant", Content: "Hi there!"},
|
{Role: "assistant", Content: "Hi there!"},
|
||||||
{Role: "user", Content: "How are you?"},
|
{Role: "user", Content: "How are you?"},
|
||||||
},
|
},
|
||||||
expected: "<|im_start|>system\n" +
|
expected: "<|im_start|>system\n" +
|
||||||
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"Hello<|im_end|>\n" +
|
"Hello<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
|
|
@ -138,73 +114,56 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||||
"<think>",
|
"<think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "parallel tool calls",
|
name: "32b_basic_without_system",
|
||||||
|
variant: Olmo3Think32B,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "Get weather in SF and NYC"},
|
{Role: "user", Content: "Hello!"},
|
||||||
{
|
|
||||||
Role: "assistant",
|
|
||||||
ToolCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
ID: "call_1",
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{"location": "San Francisco"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "call_2",
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{"location": "New York"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
|
||||||
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
|
|
||||||
},
|
|
||||||
tools: []api.Tool{
|
|
||||||
{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Parameters: api.ToolFunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]api.ToolProperty{
|
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
expected: "<|im_start|>system\n" +
|
expected: "<|im_start|>system\n" +
|
||||||
`You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. <functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
"You are a helpful AI assistant.<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"Get weather in SF and NYC<|im_end|>\n" +
|
"Hello!<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
|
||||||
`<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}, {"id": "call_2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}]</function_calls><|im_end|>` + "\n" +
|
|
||||||
"<|im_start|>environment\n" +
|
|
||||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
|
||||||
"<|im_start|>environment\n" +
|
|
||||||
`{"temperature": 55}<|im_end|>` + "\n" +
|
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
"<think>",
|
"<think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "assistant message only content no tool calls",
|
name: "32b_with_custom_system_gets_suffix",
|
||||||
|
variant: Olmo3Think32B,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "Tell me a joke"},
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
{Role: "assistant", Content: "Why did the chicken cross the road?"},
|
{Role: "user", Content: "Hello!"},
|
||||||
{Role: "user", Content: "I don't know, why?"},
|
|
||||||
},
|
},
|
||||||
expected: "<|im_start|>system\n" +
|
expected: "<|im_start|>system\n" +
|
||||||
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"Tell me a joke<|im_end|>\n" +
|
"Hello!<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
"Why did the chicken cross the road?<|im_end|>\n" +
|
"<think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "31_basic_without_system",
|
||||||
|
variant: Olmo31Think,
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||||
"<|im_start|>user\n" +
|
"<|im_start|>user\n" +
|
||||||
"I don't know, why?<|im_end|>\n" +
|
"Hello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
"<think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "31_with_custom_system_gets_suffix",
|
||||||
|
variant: Olmo31Think,
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"Hello!<|im_end|>\n" +
|
||||||
"<|im_start|>assistant\n" +
|
"<|im_start|>assistant\n" +
|
||||||
"<think>",
|
"<think>",
|
||||||
},
|
},
|
||||||
|
|
@ -212,7 +171,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
rendered, err := (&Olmo3ThinkRenderer{}).Render(tt.msgs, tt.tools, nil)
|
rendered, err := (&Olmo3ThinkRenderer{Variant: tt.variant}).Render(tt.msgs, tt.tools, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -59,12 +59,25 @@ func rendererForName(name string) Renderer {
|
||||||
case "cogito":
|
case "cogito":
|
||||||
renderer := &CogitoRenderer{isThinking: true}
|
renderer := &CogitoRenderer{isThinking: true}
|
||||||
return renderer
|
return renderer
|
||||||
|
case "deepseek3.1":
|
||||||
|
renderer := &DeepSeek3Renderer{IsThinking: true, Variant: Deepseek31}
|
||||||
|
return renderer
|
||||||
case "olmo3":
|
case "olmo3":
|
||||||
renderer := &Olmo3Renderer{}
|
renderer := &Olmo3Renderer{UseExtendedSystemMessage: false}
|
||||||
|
return renderer
|
||||||
|
case "olmo3.1":
|
||||||
|
renderer := &Olmo3Renderer{UseExtendedSystemMessage: true}
|
||||||
return renderer
|
return renderer
|
||||||
case "olmo3-think":
|
case "olmo3-think":
|
||||||
renderer := &Olmo3ThinkRenderer{}
|
// Used for Olmo-3-7B-Think and Olmo-3.1-32B-Think (same template)
|
||||||
|
renderer := &Olmo3ThinkRenderer{Variant: Olmo31Think}
|
||||||
return renderer
|
return renderer
|
||||||
|
case "olmo3-32b-think":
|
||||||
|
// Used for Olmo-3-32B-Think
|
||||||
|
renderer := &Olmo3ThinkRenderer{Variant: Olmo3Think32B}
|
||||||
|
return renderer
|
||||||
|
case "nemotron-3-nano":
|
||||||
|
return &Nemotron3NanoRenderer{}
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,10 +42,10 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) CreateHandler(c *gin.Context) {
|
func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
config := &ConfigV2{
|
config := &model.ConfigV2{
|
||||||
OS: "linux",
|
OS: "linux",
|
||||||
Architecture: "amd64",
|
Architecture: "amd64",
|
||||||
RootFS: RootFS{
|
RootFS: model.RootFS{
|
||||||
Type: "layers",
|
Type: "layers",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -126,7 +126,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
||||||
if pErr == nil {
|
if pErr == nil {
|
||||||
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
||||||
var baseConfig ConfigV2
|
var baseConfig model.ConfigV2
|
||||||
if decErr := json.NewDecoder(cfgFile).Decode(&baseConfig); decErr == nil {
|
if decErr := json.NewDecoder(cfgFile).Decode(&baseConfig); decErr == nil {
|
||||||
if config.Renderer == "" {
|
if config.Renderer == "" {
|
||||||
config.Renderer = baseConfig.Renderer
|
config.Renderer = baseConfig.Renderer
|
||||||
|
|
@ -459,7 +459,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||||
return ggml.KV{}, fmt.Errorf("no base model was found")
|
return ggml.KV{}, fmt.Errorf("no base model was found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||||
var layers []Layer
|
var layers []Layer
|
||||||
for _, layer := range baseLayers {
|
for _, layer := range baseLayers {
|
||||||
if layer.GGML != nil {
|
if layer.GGML != nil {
|
||||||
|
|
@ -789,7 +789,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createConfigLayer(layers []Layer, config ConfigV2) (*Layer, error) {
|
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||||
digests := make([]string, len(layers))
|
digests := make([]string, len(layers))
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
digests[i] = layer.Digest
|
digests[i] = layer.Digest
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ type registryOptions struct {
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Config ConfigV2
|
Config model.ConfigV2
|
||||||
ShortName string
|
ShortName string
|
||||||
ModelPath string
|
ModelPath string
|
||||||
ParentModel string
|
ParentModel string
|
||||||
|
|
@ -266,35 +266,6 @@ func (m *Model) String() string {
|
||||||
return modelfile.String()
|
return modelfile.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfigV2 struct {
|
|
||||||
ModelFormat string `json:"model_format"`
|
|
||||||
ModelFamily string `json:"model_family"`
|
|
||||||
ModelFamilies []string `json:"model_families"`
|
|
||||||
ModelType string `json:"model_type"` // shown as Parameter Size
|
|
||||||
FileType string `json:"file_type"` // shown as Quantization Level
|
|
||||||
Renderer string `json:"renderer,omitempty"`
|
|
||||||
Parser string `json:"parser,omitempty"`
|
|
||||||
|
|
||||||
RemoteHost string `json:"remote_host,omitempty"`
|
|
||||||
RemoteModel string `json:"remote_model,omitempty"`
|
|
||||||
|
|
||||||
// used for remotes
|
|
||||||
Capabilities []string `json:"capabilities,omitempty"`
|
|
||||||
ContextLen int `json:"context_length,omitempty"`
|
|
||||||
EmbedLen int `json:"embedding_length,omitempty"`
|
|
||||||
BaseName string `json:"base_name,omitempty"`
|
|
||||||
|
|
||||||
// required by spec
|
|
||||||
Architecture string `json:"architecture"`
|
|
||||||
OS string `json:"os"`
|
|
||||||
RootFS RootFS `json:"rootfs"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RootFS struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
DiffIDs []string `json:"diff_ids"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
||||||
fp, err := mp.GetManifestPath()
|
fp, err := mp.GetManifestPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -1223,7 +1223,7 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||||
|
|
||||||
models := []api.ListModelResponse{}
|
models := []api.ListModelResponse{}
|
||||||
for n, m := range ms {
|
for n, m := range ms {
|
||||||
var cf ConfigV2
|
var cf model.ConfigV2
|
||||||
|
|
||||||
if m.Config.Digest != "" {
|
if m.Config.Digest != "" {
|
||||||
f, err := m.Config.Open()
|
f, err := m.Config.Open()
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,7 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer cfgFile.Close()
|
defer cfgFile.Close()
|
||||||
|
|
||||||
var cfg ConfigV2
|
var cfg model.ConfigV2
|
||||||
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
|
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
|
||||||
t.Fatalf("decode config: %v", err)
|
t.Fatalf("decode config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||||
n := model.ParseName("test")
|
n := model.ParseName("test")
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := json.NewEncoder(&b).Encode(&ConfigV2{}); err != nil {
|
if err := json.NewEncoder(&b).Encode(&model.ConfigV2{}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -126,10 +126,10 @@ func TestRoutes(t *testing.T) {
|
||||||
t.Fatalf("failed to create model: %v", err)
|
t.Fatalf("failed to create model: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &ConfigV2{
|
config := &model.ConfigV2{
|
||||||
OS: "linux",
|
OS: "linux",
|
||||||
Architecture: "amd64",
|
Architecture: "amd64",
|
||||||
RootFS: RootFS{
|
RootFS: model.RootFS{
|
||||||
Type: "layers",
|
Type: "layers",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -775,7 +775,7 @@ func TestFilterThinkTags(t *testing.T) {
|
||||||
{Role: "user", Content: "What is the answer?"},
|
{Role: "user", Content: "What is the answer?"},
|
||||||
},
|
},
|
||||||
model: &Model{
|
model: &Model{
|
||||||
Config: ConfigV2{
|
Config: model.ConfigV2{
|
||||||
ModelFamily: "qwen3",
|
ModelFamily: "qwen3",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -793,7 +793,7 @@ func TestFilterThinkTags(t *testing.T) {
|
||||||
{Role: "user", Content: "What is the answer?"},
|
{Role: "user", Content: "What is the answer?"},
|
||||||
},
|
},
|
||||||
model: &Model{
|
model: &Model{
|
||||||
Config: ConfigV2{
|
Config: model.ConfigV2{
|
||||||
ModelFamily: "qwen3",
|
ModelFamily: "qwen3",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -815,7 +815,7 @@ func TestFilterThinkTags(t *testing.T) {
|
||||||
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
|
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
|
||||||
},
|
},
|
||||||
model: &Model{
|
model: &Model{
|
||||||
Config: ConfigV2{
|
Config: model.ConfigV2{
|
||||||
ModelFamily: "qwen3",
|
ModelFamily: "qwen3",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -833,7 +833,7 @@ func TestFilterThinkTags(t *testing.T) {
|
||||||
{Role: "user", Content: "What is the answer?"},
|
{Role: "user", Content: "What is the answer?"},
|
||||||
},
|
},
|
||||||
model: &Model{
|
model: &Model{
|
||||||
Config: ConfigV2{
|
Config: model.ConfigV2{
|
||||||
ModelFamily: "llama3",
|
ModelFamily: "llama3",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -853,7 +853,7 @@ func TestFilterThinkTags(t *testing.T) {
|
||||||
model: &Model{
|
model: &Model{
|
||||||
Name: "registry.ollama.ai/library/deepseek-r1:latest",
|
Name: "registry.ollama.ai/library/deepseek-r1:latest",
|
||||||
ShortName: "deepseek-r1:7b",
|
ShortName: "deepseek-r1:7b",
|
||||||
Config: ConfigV2{},
|
Config: model.ConfigV2{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
// ConfigV2 represents the configuration metadata for a model.
|
||||||
|
type ConfigV2 struct {
|
||||||
|
ModelFormat string `json:"model_format"`
|
||||||
|
ModelFamily string `json:"model_family"`
|
||||||
|
ModelFamilies []string `json:"model_families"`
|
||||||
|
ModelType string `json:"model_type"` // shown as Parameter Size
|
||||||
|
FileType string `json:"file_type"` // shown as Quantization Level
|
||||||
|
Renderer string `json:"renderer,omitempty"`
|
||||||
|
Parser string `json:"parser,omitempty"`
|
||||||
|
|
||||||
|
RemoteHost string `json:"remote_host,omitempty"`
|
||||||
|
RemoteModel string `json:"remote_model,omitempty"`
|
||||||
|
|
||||||
|
// used for remotes
|
||||||
|
Capabilities []string `json:"capabilities,omitempty"`
|
||||||
|
ContextLen int `json:"context_length,omitempty"`
|
||||||
|
EmbedLen int `json:"embedding_length,omitempty"`
|
||||||
|
BaseName string `json:"base_name,omitempty"`
|
||||||
|
|
||||||
|
// required by spec
|
||||||
|
Architecture string `json:"architecture"`
|
||||||
|
OS string `json:"os"`
|
||||||
|
RootFS RootFS `json:"rootfs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RootFS represents the root filesystem configuration for a model.
|
||||||
|
type RootFS struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
DiffIDs []string `json:"diff_ids"`
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue