model: add olmo3 and olmo3.1 (#13415)

This commit is contained in:
Parth Sareen 2025-12-15 15:20:04 -08:00 committed by GitHub
parent 2c639431b1
commit ffbe8e076d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 345 additions and 0 deletions

View File

@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen25VLModel{} conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration": case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{} conv = &qwen3VLModel{}
case "Olmo3ForCausalLM":
conv = &olmoModel{}
case "BertModel": case "BertModel":
conv = &bertModel{} conv = &bertModel{}
case "NomicBertModel", "NomicBertMoEModel": case "NomicBertModel", "NomicBertMoEModel":

117
convert/convert_olmo.go Normal file
View File

@ -0,0 +1,117 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type ropeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
AttentionFactor float32 `json:"attention_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
RopeType string `json:"rope_type"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
}
type olmoModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling *ropeScaling `json:"rope_scaling"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
}
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
kv["olmo3.embedding_length"] = p.HiddenSize
kv["olmo3.feed_forward_length"] = p.IntermediateSize
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
if p.RopeTheta > 0 {
kv["olmo3.rope.freq_base"] = p.RopeTheta
}
if p.RopeScaling != nil {
if p.RopeScaling.Factor > 0 {
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
}
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
}
if p.RopeScaling.AttentionFactor > 0 {
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
}
if p.RopeScaling.RopeType != "" {
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
}
}
if p.RMSNormEPS > 0 {
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if p.SlidingWindow > 0 {
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
}
if len(p.LayerTypes) > 0 {
slidingPattern := make([]bool, len(p.LayerTypes))
for i, layerType := range p.LayerTypes {
slidingPattern[i] = (layerType == "sliding_attention")
}
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
}
return kv
}
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *olmoModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"model.norm", "output_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm",
"post_attention_layernorm", "post_attention_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
}
}

View File

@ -253,6 +253,7 @@ func (kv KV) OllamaEngineRequired() bool {
"deepseekocr", "deepseekocr",
"deepseek2", "deepseek2",
"nomic-bert", "nomic-bert",
"olmo3",
}, kv.Architecture()) }, kv.Architecture())
} }
@ -841,6 +842,7 @@ func (f GGML) FlashAttention() bool {
"gemma3", "gemma3",
"gptoss", "gpt-oss", "gptoss", "gpt-oss",
"mistral3", "mistral3",
"olmo3",
"qwen3", "qwen3moe", "qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe", "qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture")) }, f.KV().String("general.architecture"))

View File

@ -13,6 +13,7 @@ import (
_ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nomicbert" _ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo3"
_ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3" _ "github.com/ollama/ollama/model/models/qwen3"

223
model/models/olmo3/model.go Normal file
View File

@ -0,0 +1,223 @@
package olmo3
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
const (
cacheTypeSWA = 0
cacheTypeCausal = 1
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
originalContextLength int
attnFactor float32
ropeType string
ropeExtrapolation float32
slidingWindowPattern []bool
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 1e4),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
attnFactor: c.Float("rope.scaling.attn_factor", 1),
ropeType: c.String("rope.scaling.type"),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
},
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, isSWA bool) ml.Tensor {
freqScale := float32(1.0)
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if !isSWA {
freqScale = 1. / o.ropeScale
if o.originalContextLength > 0 {
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
)
}
}
return nn.RoPE(ctx, states, positions, o.hiddenSize/o.numHeads, o.ropeBase, freqScale, ropeOpts...)
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := m.hiddenSize / m.numHeads
query := sa.Query.Forward(ctx, hiddenState)
query = sa.QNorm.Forward(ctx, query, m.eps)
query = query.Reshape(ctx, headDim, m.numHeads, batchSize)
query = m.Options.applyRotaryPositionEmbeddings(ctx, query, positions, isSWA)
key := sa.Key.Forward(ctx, hiddenState)
key = sa.KNorm.Forward(ctx, key, m.eps)
key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize)
key = m.Options.applyRotaryPositionEmbeddings(ctx, key, positions, isSWA)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
isSWA := m.isSWALayer(layer)
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift, isSWA), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLP *MLP
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
residual := hiddenState
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLP.Forward(ctx, hiddenState, m)
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState.Add(ctx, residual)
}
// OLMo3 has Sliding Window Attention (SWA) for 3 out of every 4 layers.
func (m *Model) isSWALayer(layerIdx int) bool {
return m.Options.slidingWindowPattern[layerIdx]
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
cacheType := cacheTypeSWA
isSWA := m.isSWALayer(i)
if !isSWA {
cacheType = cacheTypeCausal
}
wc, ok := m.Cache.(*kvcache.WrapperCache)
if !ok {
return nil, fmt.Errorf("expected *kvcache.WrapperCache, got %T", m.Cache)
}
wc.SetLayerType(cacheType)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("olmo3", New)
}