model: add rnj-1 inference support (#13354)

Jeffrey Morgan 2025-12-08 16:49:17 -08:00 committed by GitHub
parent 603ceefaa6
commit d2f334c1f7
6 changed files with 208 additions and 69 deletions

View File

@@ -2,6 +2,7 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
@@ -26,16 +27,26 @@ type gemma3Model struct {
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
RopeScaling *struct {
Type string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
} `json:"rope_scaling"`
}
const (
@@ -81,9 +92,38 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
// The sliding window pattern is either provided as the sliding_window_pattern
// key (an int) or as the layer_types key (a list of strings).
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for i := range numBlocks {
var isLocal bool
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
isLocal = p.LayerTypes[i] == "sliding_attention"
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
isLocal = (i+1)%*p.SlidingWindowPattern != 0
}
if !yield(isLocal) {
break
}
}
})
}
if p.FinalLogitSoftcap > 0 {
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
}
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
kv["gemma3.rope.scaling.type"] = "yarn"
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
}
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:
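The sliding-window conversion above collapses either config form into a per-layer boolean list, where true marks a local (sliding-attention) layer. A minimal standalone sketch of the same derivation, not part of this change, assuming a hypothetical 12-layer model with sliding_window_pattern = 6:

package main

import "fmt"

// deriveSlidingWindowPattern mirrors the conversion logic above: layer_types
// (a list of strings) takes precedence when present; otherwise every
// pattern-th layer is global and all other layers are local.
func deriveSlidingWindowPattern(numBlocks int, layerTypes []string, pattern uint32) []bool {
	out := make([]bool, 0, numBlocks)
	for i := 0; i < numBlocks; i++ {
		var isLocal bool
		if i < len(layerTypes) {
			isLocal = layerTypes[i] == "sliding_attention"
		} else if pattern > 0 {
			isLocal = (i+1)%int(pattern) != 0
		}
		out = append(out, isLocal)
	}
	return out
}

func main() {
	// With pattern = 6, layers 6 and 12 (1-indexed) are global; the rest slide.
	fmt.Println(deriveSlidingWindowPattern(12, nil, 6))
	// [true true true true true false true true true true true false]
}

The rope_scaling handling in the same hunk is a straight key-for-key copy: when rope_type is "yarn" with a positive factor, each JSON field lands in the corresponding gemma3.rope.scaling.* entry, with defaults of 1.0, 64.0 and 1.0 for the extrapolation factor, beta_fast and beta_slow.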

View File

@@ -58,6 +58,18 @@ func WithAttentionFactor(attentionFactor float32) func(*Options) {
}
}
func WithBetaFast(betaFast float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.BetaFast = betaFast
}
}
func WithBetaSlow(betaSlow float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.BetaSlow = betaSlow
}
}
func WithMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3
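WithBetaFast and WithBetaSlow follow the same functional-options pattern as the surrounding helpers: each returns a closure that mutates an Options value, so callers can build the option list conditionally and apply it in order. A self-contained sketch of the pattern, using a simplified stand-in for the real Options struct:

package main

import "fmt"

// Options is a simplified stand-in for the rope options type in this diff;
// the real struct lives in the rope package and carries more fields.
type Options struct {
	YaRN struct {
		BetaFast, BetaSlow float32
	}
}

func WithBetaFast(betaFast float32) func(*Options) {
	return func(opts *Options) { opts.YaRN.BetaFast = betaFast }
}

func WithBetaSlow(betaSlow float32) func(*Options) {
	return func(opts *Options) { opts.YaRN.BetaSlow = betaSlow }
}

func main() {
	// Apply the options in order; fields left unset keep their zero values,
	// which is why the model code only appends YaRN options when needed.
	var opts Options
	for _, apply := range []func(*Options){WithBetaFast(64), WithBetaSlow(1)} {
		apply(&opts)
	}
	fmt.Println(opts.YaRN.BetaFast, opts.YaRN.BetaSlow) // 64 1
}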

View File

@@ -16,7 +16,7 @@ import (
type Model struct {
model.Base
model.SentencePiece
model.TextProcessor
*VisionModel `gguf:"v"`
*TextModel
@@ -54,24 +54,35 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add it in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
@@ -141,8 +152,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenState = m.Output.Forward(ctx, hiddenState)
if m.TextConfig.finalLogitSoftcap > 0.0 {
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap))
}
return hiddenState, nil
}
func init() {
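Final logit softcapping, applied in Forward above via Scale, Tanh and Scale, is the element-wise map logits -> softcap * tanh(logits / softcap), which smoothly bounds every logit to (-softcap, +softcap). A scalar sketch of the same transform, not part of this change, using the converter's old default cap of 30:

package main

import (
	"fmt"
	"math"
)

// softcap squashes x smoothly into (-limit, +limit); a limit of 0 disables
// the transform, matching the finalLogitSoftcap > 0 guard in Forward above.
func softcap(x, limit float64) float64 {
	if limit <= 0 {
		return x
	}
	return limit * math.Tanh(x/limit)
}

func main() {
	// Small logits pass through almost unchanged; large ones saturate near ±30.
	for _, x := range []float64{1, 10, 50, 200} {
		fmt.Printf("softcap(%g, 30) = %.2f\n", x, softcap(x, 30))
	}
}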

View File

@@ -2,6 +2,7 @@ package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -15,12 +16,32 @@ type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
ropeLocalBase float32
largeModelScaling bool
slidingWindowPattern []bool
ropeBase float32
ropeType string
ropeOriginalContext int
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
finalLogitSoftcap float32
}
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX())
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOriginalContext),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithAttentionFactor(attnFactor),
rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow),
)
}
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
}
type TextModel struct {
@@ -48,21 +69,35 @@ func newTextModel(c fs.Config) *TextModel {
m := TextModel{
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: 1,
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeBase: c.Float("rope.freq_base", 1000000.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
ropeType: c.String("rope.scaling.type"),
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeScale: c.Float("rope.scaling.factor", 1.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
},
}
// Google's Gemma 3 release with sliding window attention does
// not use final logit softcapping, so it is forced to 0.0 here
// TODO (jmorganca): this should ideally be set to 0.0 in the
// model configuration instead of here, as future versions of
// models may include both sliding window attention and final
// logit softcapping.
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
m.TextConfig.finalLogitSoftcap = 0.0
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
@@ -79,13 +114,26 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
return opts.ropeLocalBase
}
// Standard Gemma 3: only every n-th layer is global,
// where n = gemmaGlobalCacheCount; otherwise use
// the local rope base
if (layer+1)%gemmaGlobalCacheCount > 0 {
return opts.ropeLocalBase
}
// default to global rope base
return opts.ropeBase
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
ropeBase := opts.ropeBaseForLayer(layer)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
@@ -114,12 +162,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextConfig.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextConfig.ropeGlobalBase
}
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
}
type TextMLP struct {
@@ -207,6 +250,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}
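Two details of the text model above are easy to miss: ropeBaseForLayer uses the local base whenever the stored sliding-window pattern marks a layer as sliding, and otherwise applies the fixed every-n-th-layer rule; and the YaRN attention factor is derived from the scaling factor as 1 / (1 + 0.1 * ln(ropeScale)). A small sketch of both, not part of this change, with an assumed gemmaGlobalCacheCount of 6 and an illustrative scaling factor of 8:

package main

import (
	"fmt"
	"math"
)

// Assumed value for illustration: in standard Gemma 3 every 6th layer is global.
const gemmaGlobalCacheCount = 6

// ropeBaseForLayer mirrors the selection above: a layer marked as sliding in
// the stored pattern uses the local base; otherwise only every
// gemmaGlobalCacheCount-th layer gets the global base.
func ropeBaseForLayer(layer int, pattern []bool, localBase, globalBase float32) float32 {
	if pattern != nil && pattern[layer] {
		return localBase
	}
	if (layer+1)%gemmaGlobalCacheCount > 0 {
		return localBase
	}
	return globalBase
}

func main() {
	// YaRN attention factor for an illustrative scaling factor of 8.
	ropeScale := 8.0
	fmt.Printf("attention factor: %.3f\n", 1.0/(1.0+0.1*math.Log(ropeScale))) // ≈ 0.828

	// Without a stored pattern, only layers 5 and 11 (0-indexed) get the global base.
	for layer := 0; layer < 12; layer++ {
		fmt.Println(layer, ropeBaseForLayer(layer, nil, 10000, 1000000))
	}
}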

View File

@@ -300,18 +300,13 @@ func filesForModel(path string) ([]string, error) {
}
files = append(files, js...)
// only include tokenizer.model if tokenizer.json is not present
if !slices.ContainsFunc(files, func(s string) bool {
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
}) {
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists; tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be an unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
// tokenizer.model might be an unresolved git lfs reference; error if it is
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
return files, nil
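The comments above flag that tokenizer.model may be an unresolved Git LFS pointer: a short text stub starting with the LFS spec line instead of the binary SentencePiece model, which is why the glob filters on content type. A hedged sketch of that kind of check using only the standard library; the isLFSPointer helper is hypothetical and not part of this diff:

package main

import (
	"bytes"
	"fmt"
	"net/http"
	"os"
)

// isLFSPointer reports whether path looks like an unresolved Git LFS pointer
// (a small text file beginning with the LFS spec line) rather than a real,
// binary tokenizer.model.
func isLFSPointer(path string) (bool, error) {
	f, err := os.Open(path)
	if err != nil {
		return false, err
	}
	defer f.Close()

	buf := make([]byte, 512)
	n, _ := f.Read(buf)
	buf = buf[:n]

	if http.DetectContentType(buf) == "application/octet-stream" {
		return false, nil // binary content: very likely the real model
	}
	return bytes.HasPrefix(buf, []byte("version https://git-lfs.github.com/spec/v1")), nil
}

func main() {
	ok, err := isLFSPointer("tokenizer.model")
	fmt.Println(ok, err)
}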

View File

@@ -888,6 +888,37 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json",
},
},
{
name: "safetensors with both tokenizer.json and tokenizer.model",
setup: func(dir string) error {
// Create binary content for tokenizer.model (application/octet-stream)
binaryContent := make([]byte, 512)
for i := range binaryContent {
binaryContent[i] = byte(i % 256)
}
files := []string{
"model-00001-of-00001.safetensors",
"config.json",
"tokenizer.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
}
}
// Write tokenizer.model as binary
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
return err
}
return nil
},
wantFiles: []string{
"model-00001-of-00001.safetensors",
"config.json",
"tokenizer.json",
"tokenizer.model",
},
},
{
name: "safetensors with consolidated files - prefers model files",
setup: func(dir string) error {