From ffbe8e076df9e2e67aab016ea3ec64822369b725 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Mon, 15 Dec 2025 15:20:04 -0800 Subject: [PATCH] model: add olmo3 and olmo3.1 (#13415) --- convert/convert.go | 2 + convert/convert_olmo.go | 117 +++++++++++++++++++ fs/ggml/ggml.go | 2 + model/models/models.go | 1 + model/models/olmo3/model.go | 223 ++++++++++++++++++++++++++++++++++++ 5 files changed, 345 insertions(+) create mode 100644 convert/convert_olmo.go create mode 100644 model/models/olmo3/model.go diff --git a/convert/convert.go b/convert/convert.go index f08467952..a6d286683 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &qwen25VLModel{} case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration": conv = &qwen3VLModel{} + case "Olmo3ForCausalLM": + conv = &olmoModel{} case "BertModel": conv = &bertModel{} case "NomicBertModel", "NomicBertMoEModel": diff --git a/convert/convert_olmo.go b/convert/convert_olmo.go new file mode 100644 index 000000000..f75c68477 --- /dev/null +++ b/convert/convert_olmo.go @@ -0,0 +1,117 @@ +package convert + +import ( + "cmp" + + "github.com/ollama/ollama/fs/ggml" +) + +type ropeScaling struct { + Factor float32 `json:"factor"` + OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"` + AttentionFactor float32 `json:"attention_factor"` + BetaFast float32 `json:"beta_fast"` + BetaSlow float32 `json:"beta_slow"` + RopeType string `json:"rope_type"` + ExtrapolationFactor float32 `json:"extrapolation_factor"` +} + +type olmoModel struct { + ModelParameters + + HiddenSize uint32 `json:"hidden_size"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + RMSNormEPS float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + RopeScaling *ropeScaling `json:"rope_scaling"` + SlidingWindow uint32 `json:"sliding_window"` + LayerTypes []string `json:"layer_types"` +} + +var _ ModelConverter = (*olmoModel)(nil) + +func (p *olmoModel) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "olmo3" + kv["olmo3.block_count"] = p.NumHiddenLayers + kv["olmo3.context_length"] = p.MaxPositionEmbeddings + kv["olmo3.embedding_length"] = p.HiddenSize + kv["olmo3.feed_forward_length"] = p.IntermediateSize + kv["olmo3.attention.head_count"] = p.NumAttentionHeads + kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) + + if p.RopeTheta > 0 { + kv["olmo3.rope.freq_base"] = p.RopeTheta + } + + if p.RopeScaling != nil { + if p.RopeScaling.Factor > 0 { + kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor + } + if p.RopeScaling.OriginalMaxPositionEmbeds > 0 { + kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds + } + if p.RopeScaling.AttentionFactor > 0 { + kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor + } + if p.RopeScaling.RopeType != "" { + kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType + } + } + + if p.RMSNormEPS > 0 { + kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS + } + + if p.SlidingWindow > 0 { + kv["olmo3.attention.sliding_window"] = p.SlidingWindow + } + + if len(p.LayerTypes) > 0 { + slidingPattern := make([]bool, len(p.LayerTypes)) + for i, layerType := range p.LayerTypes { + slidingPattern[i] = (layerType == "sliding_attention") + } + kv["olmo3.attention.sliding_window_pattern"] = slidingPattern + } + + return kv +} + +func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor { + out := make([]*ggml.Tensor, 0, len(ts)) + for _, t := range ts { + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (p *olmoModel) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.layers", "blk", + "model.norm", "output_norm", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "self_attn.q_norm", "attn_q_norm", + "self_attn.k_norm", "attn_k_norm", + "post_attention_layernorm", "post_attention_norm", + "post_feedforward_layernorm", "post_ffw_norm", + "mlp.gate_proj", "ffn_gate", + "mlp.down_proj", "ffn_down", + "mlp.up_proj", "ffn_up", + } +} diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 691ea32b4..56614a321 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -253,6 +253,7 @@ func (kv KV) OllamaEngineRequired() bool { "deepseekocr", "deepseek2", "nomic-bert", + "olmo3", }, kv.Architecture()) } @@ -841,6 +842,7 @@ func (f GGML) FlashAttention() bool { "gemma3", "gptoss", "gpt-oss", "mistral3", + "olmo3", "qwen3", "qwen3moe", "qwen3vl", "qwen3vlmoe", }, f.KV().String("general.architecture")) diff --git a/model/models/models.go b/model/models/models.go index 85bf9a7f3..b471e8166 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -13,6 +13,7 @@ import ( _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/nomicbert" + _ "github.com/ollama/ollama/model/models/olmo3" _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen3" diff --git a/model/models/olmo3/model.go b/model/models/olmo3/model.go new file mode 100644 index 000000000..523c00e68 --- /dev/null +++ b/model/models/olmo3/model.go @@ -0,0 +1,223 @@ +package olmo3 + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +const ( + cacheTypeSWA = 0 + cacheTypeCausal = 1 +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + eps, ropeBase, ropeScale float32 + + originalContextLength int + attnFactor float32 + + ropeType string + ropeExtrapolation float32 + + slidingWindowPattern []bool +} + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +func New(c fs.Config) (model.Model, error) { + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + } + + processor := model.NewBytePairEncoding( + &vocabulary, + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + ) + + m := Model{ + TextProcessor: processor, + Layers: make([]Layer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base", 1e4), + ropeScale: c.Float("rope.scaling.factor", 1), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + attnFactor: c.Float("rope.scaling.attn_factor", 1), + ropeType: c.String("rope.scaling.type"), + ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0), + slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), + }, + } + + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), + kvcache.NewCausalCache(m.Shift), + ) + + return &m, nil +} + +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + QNorm *nn.RMSNorm `gguf:"attn_q_norm"` + KNorm *nn.RMSNorm `gguf:"attn_k_norm"` +} + +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, isSWA bool) ml.Tensor { + freqScale := float32(1.0) + ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()} + + if !isSWA { + freqScale = 1. / o.ropeScale + if o.originalContextLength > 0 { + ropeOpts = append(ropeOpts, + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(o.ropeExtrapolation), + ) + } + } + + return nn.RoPE(ctx, states, positions, o.hiddenSize/o.numHeads, o.ropeBase, freqScale, ropeOpts...) +} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor { + batchSize := hiddenState.Dim(1) + headDim := m.hiddenSize / m.numHeads + + query := sa.Query.Forward(ctx, hiddenState) + query = sa.QNorm.Forward(ctx, query, m.eps) + query = query.Reshape(ctx, headDim, m.numHeads, batchSize) + query = m.Options.applyRotaryPositionEmbeddings(ctx, query, positions, isSWA) + + key := sa.Key.Forward(ctx, hiddenState) + key = sa.KNorm.Forward(ctx, key, m.eps) + key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize) + key = m.Options.applyRotaryPositionEmbeddings(ctx, key, positions, isSWA) + + value := sa.Value.Forward(ctx, hiddenState) + value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, m.hiddenSize, batchSize) + + return sa.Output.Forward(ctx, attention) +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + isSWA := m.isSWALayer(layer) + return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift, isSWA), nil +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type Layer struct { + SelfAttention *SelfAttention + PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"` + MLP *MLP + PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"` +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor { + residual := hiddenState + + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA) + + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps) + + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLP.Forward(ctx, hiddenState, m) + hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps) + + return hiddenState.Add(ctx, residual) +} + +// OLMo3 has Sliding Window Attention (SWA) for 3 out of every 4 layers. +func (m *Model) isSWALayer(layerIdx int) bool { + return m.Options.slidingWindowPattern[layerIdx] +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + cacheType := cacheTypeSWA + + isSWA := m.isSWALayer(i) + if !isSWA { + cacheType = cacheTypeCausal + } + + wc, ok := m.Cache.(*kvcache.WrapperCache) + if !ok { + return nil, fmt.Errorf("expected *kvcache.WrapperCache, got %T", m.Cache) + } + wc.SetLayerType(cacheType) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState), nil +} + +func init() { + model.Register("olmo3", New) +}