From d2f334c1f7822efe3470f41720dc121e5b19e891 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Mon, 8 Dec 2025 16:49:17 -0800
Subject: [PATCH] model: add rnj-1 inference support (#13354)

---
 convert/convert_gemma3.go         | 64 +++++++++++++++++----
 ml/nn/rope/options.go             | 12 ++++
 model/models/gemma3/model.go      | 57 ++++++++++++-------
 model/models/gemma3/model_text.go | 94 ++++++++++++++++++++++---------
 parser/parser.go                  | 19 +++----
 parser/parser_test.go             | 31 ++++++++++
 6 files changed, 208 insertions(+), 69 deletions(-)

diff --git a/convert/convert_gemma3.go b/convert/convert_gemma3.go
index 27b99f575..5e6e6904c 100644
--- a/convert/convert_gemma3.go
+++ b/convert/convert_gemma3.go
@@ -2,6 +2,7 @@ package convert
 
 import (
 	"cmp"
+	"slices"
 
 	"github.com/ollama/ollama/fs/ggml"
 )
@@ -26,16 +27,26 @@ type gemma3Model struct {
 		NumChannels uint32 `json:"num_channels"` // num_channels 3
 		PatchSize   uint32 `json:"patch_size"`   // patch_size 14
 	} `json:"vision_config"`
-	MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
-	NumAttentionHeads        uint32  `json:"num_attention_heads"`
-	NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
-	RMSNormEPS               float32 `json:"rms_norm_eps"`
-	HeadDim                  uint32  `json:"head_dim"`
-	FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
-	RopeLocalTheta           float32 `json:"rope_local_base_freq"`
-	RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
-	SlidingWindow            uint32  `json:"sliding_window"`
-	MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
+	MaxPositionEmbeddings    uint32   `json:"max_position_embeddings"`
+	NumAttentionHeads        uint32   `json:"num_attention_heads"`
+	NumKeyValueHeads         uint32   `json:"num_key_value_heads"`
+	RMSNormEPS               float32  `json:"rms_norm_eps"`
+	HeadDim                  uint32   `json:"head_dim"`
+	FinalLogitSoftcap        float32  `json:"final_logit_softcapping"`
+	RopeLocalTheta           float32  `json:"rope_local_base_freq"`
+	RopeTheta                float32  `json:"rope_theta"`
+	SlidingWindow            uint32   `json:"sliding_window"`
+	SlidingWindowPattern     *uint32  `json:"sliding_window_pattern"`
+	LayerTypes               []string `json:"layer_types"`
+	MultiModalTokensPerImage uint32   `json:"mm_tokens_per_image"`
+	RopeScaling              *struct {
+		Type                          string  `json:"rope_type"`
+		Factor                        float32 `json:"factor"`
+		OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+		ExtrapolationFactor           float32 `json:"extrapolation_factor"`
+		BetaFast                      float32 `json:"beta_fast"`
+		BetaSlow                      float32 `json:"beta_slow"`
+	} `json:"rope_scaling"`
 }
 
 const (
@@ -81,9 +92,38 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 		kv["gemma3.attention.key_length"] = p.HeadDim
 		kv["gemma3.attention.value_length"] = p.HeadDim
 		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
-		kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
+
+		// The sliding window pattern is either provided as the sliding_window_pattern
+		// key (an int) or as the layer_types key (a list of strings).
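+		// For example, with sliding_window_pattern = 6 every layer except each
+		// 6th one is marked as local (sliding); with layer_types, entries equal
+		// to "sliding_attention" are marked as local.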
+		if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
+			kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
+				for i := range numBlocks {
+					var isLocal bool
+					if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
+						isLocal = p.LayerTypes[i] == "sliding_attention"
+					} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
+						isLocal = (i+1)%*p.SlidingWindowPattern != 0
+					}
+					if !yield(isLocal) {
+						break
+					}
+				}
+			})
+		}
+		if p.FinalLogitSoftcap > 0 {
+			kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
+		}
 		kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
-		kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
+		kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
+		if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
+			kv["gemma3.rope.scaling.type"] = "yarn"
+			kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
+			kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
+			kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
+			kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
+			kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
+		}
+
 		kv["gemma3.embedding_length"] = p.HiddenSize
 		kv["gemma3.feed_forward_length"] = p.IntermediateSize
 	default:
diff --git a/ml/nn/rope/options.go b/ml/nn/rope/options.go
index 03cc5211a..84b926773 100644
--- a/ml/nn/rope/options.go
+++ b/ml/nn/rope/options.go
@@ -58,6 +58,18 @@ func WithAttentionFactor(attentionFactor float32) func(*Options) {
 	}
 }
 
+func WithBetaFast(betaFast float32) func(*Options) {
+	return func(opts *Options) {
+		opts.YaRN.BetaFast = betaFast
+	}
+}
+
+func WithBetaSlow(betaSlow float32) func(*Options) {
+	return func(opts *Options) {
+		opts.YaRN.BetaSlow = betaSlow
+	}
+}
+
 func WithMRoPE(sections []int) func(*Options) {
 	return func(opts *Options) {
 		opts.Type |= 1 << 3
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 62f51074a..e595f1863 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -16,7 +16,7 @@ import (
 
 type Model struct {
 	model.Base
-	model.SentencePiece
+	model.TextProcessor
 
 	*VisionModel `gguf:"v"`
 	*TextModel
@@ -54,24 +54,35 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
 }
 
 func New(c fs.Config) (model.Model, error) {
-	m := Model{
-		SentencePiece: model.NewSentencePiece(
-			&model.Vocabulary{
-				Values: c.Strings("tokenizer.ggml.tokens"),
-				Scores: c.Floats("tokenizer.ggml.scores"),
-				Types:  c.Ints("tokenizer.ggml.token_type"),
-				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
-				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
-				EOS: append(
-					[]int32{
-						int32(c.Uint("tokenizer.ggml.eos_token_id")),
-						int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
-					},
-					c.Ints("tokenizer.ggml.eos_token_ids")...,
-				),
-			},
-		),
+	vocabulary := model.Vocabulary{
+		Values: c.Strings("tokenizer.ggml.tokens"),
+		Scores: c.Floats("tokenizer.ggml.scores"),
+		Types:  c.Ints("tokenizer.ggml.token_type"),
+		Merges: c.Strings("tokenizer.ggml.merges"),
+		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+		EOS: append(
+			[]int32{
+				int32(c.Uint("tokenizer.ggml.eos_token_id")),
+			},
+			c.Ints("tokenizer.ggml.eos_token_ids")...,
+		),
+	}
+
+	var processor model.TextProcessor
+	switch c.String("tokenizer.ggml.model") {
+	case "gpt2":
+		processor = model.NewBytePairEncoding(&vocabulary)
+	default:
+		// Previous uploads of Gemma 3 on Ollama did not have token 106
+		// (i.e. "<end_of_turn>"), so we need to add it in case it's not already present
+		vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
+		processor = model.NewSentencePiece(&vocabulary)
+	}
+
+	m := Model{
+		TextProcessor:  processor,
 		ImageProcessor: newImageProcessor(c),
 		VisionModel:    newVisionModel(c),
 		TextModel:      newTextModel(c),
@@ -141,8 +152,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
-	return m.Output.Forward(ctx, hiddenStates), nil
+	hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
+	hiddenState = m.Output.Forward(ctx, hiddenState)
+
+	if m.TextConfig.finalLogitSoftcap > 0.0 {
+		hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap))
+		hiddenState = hiddenState.Tanh(ctx)
+		hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap))
+	}
+
+	return hiddenState, nil
 }
 
 func init() {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index ddb30c415..f76fba741 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -2,6 +2,7 @@ package gemma3
 
 import (
 	"math"
+	"slices"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -15,12 +16,32 @@ type TextConfig struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
 	eps, ropeScale                   float32
-	ropeLocalBase, ropeGlobalBase    float32
+	ropeLocalBase                    float32
 	largeModelScaling                bool
+	slidingWindowPattern             []bool
+	ropeBase                         float32
+	ropeType                         string
+	ropeOriginalContext              int
+	ropeExtrapolation                float32
+	ropeBetaFast                     float32
+	ropeBetaSlow                     float32
+	finalLogitSoftcap                float32
 }
 
 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
-	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX())
+	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
+	if o.ropeType == "yarn" {
+		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+		ropeOpts = append(ropeOpts,
+			rope.WithOriginalContextLength(o.ropeOriginalContext),
+			rope.WithExtrapolationFactor(o.ropeExtrapolation),
+			rope.WithAttentionFactor(attnFactor),
+			rope.WithBetaFast(o.ropeBetaFast),
+			rope.WithBetaSlow(o.ropeBetaSlow),
+		)
+	}
+
+	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
 }
 
 type TextModel struct {
@@ -48,21 +69,35 @@ func newTextModel(c fs.Config) *TextModel {
 	m := TextModel{
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
-			hiddenSize:     int(c.Uint("embedding_length")),
-			numHeads:       int(c.Uint("attention.head_count")),
-			numKVHeads:     int(c.Uint("attention.head_count_kv")),
-			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
-			attnValLen:     int(c.Uint("attention.value_length", 256)),
-			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      1,
-			// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
-			// (8 instead of 1)
-			// ropeScale: c.Float("rope.scaling.factor", 1.0),
+			hiddenSize:           int(c.Uint("embedding_length")),
+			numHeads:             int(c.Uint("attention.head_count")),
+			numKVHeads:           int(c.Uint("attention.head_count_kv")),
+			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
+			attnValLen:           int(c.Uint("attention.value_length", 256)),
+			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
+			ropeBase:             c.Float("rope.freq_base", 1000000.0),
+			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
+			ropeType:             c.String("rope.scaling.type"),
+			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
+			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
+			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
+			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
+			ropeScale:            c.Float("rope.scaling.factor", 1.0),
+			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
 		},
 	}
+
+	// Google's Gemma 3 releases with sliding window attention do not use
+	// final logit softcapping, so force it to 0.0 here.
+	// TODO (jmorganca): this should ideally be set to 0.0 in the
+	// model configuration instead of here, as future versions of
+	// models may include both sliding window attention and final
+	// logit softcapping.
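+	// For example, a GGUF converted from a config that sets layer_types or
+	// sliding_window_pattern will have attention.sliding_window_pattern
+	// populated; if any layer is marked as sliding, finalLogitSoftcap is
+	// zeroed below even when final_logit_softcapping is present in the
+	// metadata.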
+	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
+		m.TextConfig.finalLogitSoftcap = 0.0
+	}
 
 	if numBlocks == gemma27BLayerCount {
 		m.largeModelScaling = true
 	}
@@ -79,13 +114,26 @@ type TextSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
+func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
+	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
+		return opts.ropeLocalBase
+	}
+
+	// Standard Gemma 3: only every n-th layer is global,
+	// where n = gemmaGlobalCacheCount; otherwise use
+	// the local rope base
+	if (layer+1)%gemmaGlobalCacheCount > 0 {
+		return opts.ropeLocalBase
+	}
+
+	// default to the global rope base
+	return opts.ropeBase
+}
+
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 
-	ropeBase := opts.ropeLocalBase
-	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = opts.ropeGlobalBase
-	}
+	ropeBase := opts.ropeBaseForLayer(layer)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
@@ -114,12 +162,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := m.TextConfig.ropeLocalBase
-	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = m.TextConfig.ropeGlobalBase
-	}
-
-	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
 }
 
 type TextMLP struct {
@@ -207,6 +250,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
 	}
 
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	return hiddenState
+	return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 }
diff --git a/parser/parser.go b/parser/parser.go
index 7d52c3387..1f4764449 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -300,18 +300,13 @@ func filesForModel(path string) ([]string, error) {
 	}
 
 	files = append(files, js...)
-	// only include tokenizer.model is tokenizer.json is not present
-	if !slices.ContainsFunc(files, func(s string) bool {
-		return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
-	}) {
-		if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
-			// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
-			// tokenizer.model might be a unresolved git lfs reference; error if it is
-			files = append(files, tks...)
-		} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
-			// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
-			files = append(files, tks...)
-		}
+	// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
+	// tokenizer.model might be an unresolved git lfs reference; error if it is
+	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
+		files = append(files, tks...)
+	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
+		// sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
+		files = append(files, tks...)
 	}
 
 	return files, nil
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 3300aad3e..4b97e8c20 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -888,6 +888,37 @@ func TestFilesForModel(t *testing.T) {
 				"tokenizer.json",
 			},
 		},
+		{
+			name: "safetensors with both tokenizer.json and tokenizer.model",
+			setup: func(dir string) error {
+				// Create binary content for tokenizer.model (application/octet-stream)
+				binaryContent := make([]byte, 512)
+				for i := range binaryContent {
+					binaryContent[i] = byte(i % 256)
+				}
+				files := []string{
+					"model-00001-of-00001.safetensors",
+					"config.json",
+					"tokenizer.json",
+				}
+				for _, file := range files {
+					if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
+						return err
+					}
+				}
+				// Write tokenizer.model as binary
+				if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
+					return err
+				}
+				return nil
+			},
+			wantFiles: []string{
+				"model-00001-of-00001.safetensors",
+				"config.json",
+				"tokenizer.json",
+				"tokenizer.model",
+			},
+		},
 		{
 			name: "safetensors with consolidated files - prefers model files",
 			setup: func(dir string) error {
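Reviewer note (illustrative, not part of the patch): the softcapping block added to Model.Forward above is the usual logit soft cap, logits = softcap * tanh(logits / softcap), expressed as Scale -> Tanh -> Scale on the logits tensor. A minimal scalar sketch of the same math, using only the Go standard library:

package main

import (
	"fmt"
	"math"
)

// softcap mirrors the tensor ops in Model.Forward: x -> c * tanh(x / c),
// which smoothly bounds the logits to the open interval (-c, c).
// A value of c <= 0 disables capping, matching the `> 0.0` guard in the patch.
func softcap(x, c float64) float64 {
	if c <= 0 {
		return x
	}
	return c * math.Tanh(x/c)
}

func main() {
	for _, x := range []float64{-100, -5, 0, 5, 100} {
		fmt.Printf("x = %6.1f -> %8.4f\n", x, softcap(x, 30))
	}
}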