From 603ceefaa67feee627e01cae1df1e0642e1c868f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Nov 2025 15:17:03 -0800 Subject: [PATCH] refactor rope change to a flatter directory structure and group the options with the function update models to call rope in one place --- ml/nn/{fast => }/rope.go | 5 ++--- ml/nn/rope/{rope.go => options.go} | 1 + model/models/deepseek2/model.go | 16 +++++++--------- model/models/deepseekocr/model_text.go | 11 +++++------ model/models/gemma2/model.go | 11 +++++++---- model/models/gemma3/model_text.go | 11 +++++++---- model/models/gemma3n/model_text.go | 11 +++++++---- model/models/gptoss/model.go | 17 ++++++++--------- model/models/llama/model.go | 13 +++++++------ model/models/llama4/model_text.go | 11 +++++++---- model/models/mistral3/model_text.go | 11 +++++++---- model/models/mistral3/model_vision.go | 8 ++++---- model/models/mllama/model_text.go | 13 ++++++++----- model/models/nomicbert/model.go | 9 ++++++--- model/models/qwen2/model.go | 13 +++++++------ model/models/qwen25vl/model_text.go | 14 ++++++++++---- model/models/qwen25vl/model_vision.go | 8 ++++---- model/models/qwen3/model.go | 3 +-- model/models/qwen3vl/model.go | 2 +- model/models/qwen3vl/model_text.go | 9 ++++----- model/models/qwen3vl/model_vision.go | 8 ++++---- 21 files changed, 114 insertions(+), 91 deletions(-) rename ml/nn/{fast => }/rope.go (71%) rename ml/nn/rope/{rope.go => options.go} (97%) diff --git a/ml/nn/fast/rope.go b/ml/nn/rope.go similarity index 71% rename from ml/nn/fast/rope.go rename to ml/nn/rope.go index b45938ebf..967aa94f9 100644 --- a/ml/nn/fast/rope.go +++ b/ml/nn/rope.go @@ -1,5 +1,4 @@ -// fast provides implementations of fast (fused) operations for increased performance. -package fast +package nn import ( "github.com/ollama/ollama/ml" @@ -8,7 +7,7 @@ import ( // fastRoPE is an interface for tensors that support fast rotary positional embedding. type fastRoPE interface { - RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor + RoPE(ctx ml.Context, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor } // RoPE applies rotary positional embedding to tensor `t`. 
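The two files above define the whole of the consolidated API: nn.RoPE, moved up from ml/nn/fast, plus the functional options that now live in ml/nn/rope/options.go. Every model diff that follows applies the same shape of change, so here is a minimal sketch of that pattern, assuming only the identifiers visible in this patch (nn.RoPE, rope.WithTypeNeoX, ml.Context, ml.Tensor); the Options struct and its three fields are illustrative, not copied from any one model.

package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
)

// Options stands in for a model's per-model option struct; the
// concrete fields vary per model and these three are illustrative.
type Options struct {
	headDim             int
	ropeBase, ropeScale float32
}

// applyRotaryPositionEmbeddings is the single place the model calls
// RoPE, so the dimension, base, scale, and functional options cannot
// drift between the attention forward path and the cache Shift hook.
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	return nn.RoPE(ctx, states, positions, o.headDim, o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX())
}

Both call sites then reduce to one-liners, e.g. opts.applyRotaryPositionEmbeddings(ctx, query, positions) in Forward and m.applyRotaryPositionEmbeddings(ctx, key, shift) in Shift, which is the per-model change repeated through the rest of the diff.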
diff --git a/ml/nn/rope/rope.go b/ml/nn/rope/options.go similarity index 97% rename from ml/nn/rope/rope.go rename to ml/nn/rope/options.go index e01ac152a..03cc5211a 100644 --- a/ml/nn/rope/rope.go +++ b/ml/nn/rope/options.go @@ -1,3 +1,4 @@ +// Package rope provides options for RoPE package rope import "github.com/ollama/ollama/ml" diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go index e3cab3b23..576076aab 100644 --- a/model/models/deepseek2/model.go +++ b/model/models/deepseek2/model.go @@ -10,7 +10,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -42,13 +41,12 @@ type Options struct { kqScale float64 } -func (o Options) RoPEOptions() []func(*rope.Options) { - attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) - return []func(*rope.Options){ +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1./o.ropeScale, rope.WithOriginalContextLength(o.originalContextLength), rope.WithExtrapolationFactor(1.), - rope.WithAttentionFactor(attnFactor), - } + rope.WithAttentionFactor(float32(1.0/(1.0+0.1*math.Log(float64(o.ropeScale))))), + ) } type Attention struct { @@ -91,8 +89,8 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor compressedKV.Stride(1), compressedKV.Dim(1), ) - qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) - kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) 
+ qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions) + kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions) kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) var attention ml.Tensor @@ -327,7 +325,7 @@ func New(c fs.Config) (model.Model, error) { } func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { diff --git a/model/models/deepseekocr/model_text.go b/model/models/deepseekocr/model_text.go index 1513b1388..ab6221ccf 100644 --- a/model/models/deepseekocr/model_text.go +++ b/model/models/deepseekocr/model_text.go @@ -6,7 +6,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" ) @@ -20,7 +19,7 @@ type textModel struct { } func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil + return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type textOptions struct { @@ -38,8 +37,8 @@ func (o textOptions) headDim() int { return o.hiddenSize / o.numHeads } -func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor { - return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX()) +func (o textOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX()) } type textBlock struct { @@ -83,8 +82,8 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso value := m.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1) - query = opts.applyRotaryPositionalEmbedding(ctx, query, positions) - key = opts.applyRotaryPositionalEmbedding(ctx, key, positions) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, -1, attention.Dim(2)) diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 06c71fc3b..7b0aa2f01 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -22,6 +21,10 @@ type Options struct { largeModelScaling bool } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.attnKeyLen, o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX()) +} + type Model struct { model.Base model.SentencePiece @@ -88,7 +91,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 
1./opts.ropeScale, rope.WithTypeNeoX()) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -98,7 +101,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -128,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type MLP struct { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 8d1a1be6a..ddb30c415 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -20,6 +19,10 @@ type TextConfig struct { largeModelScaling bool } +func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX()) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []TextLayer `gguf:"blk"` @@ -87,7 +90,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -98,7 +101,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -116,7 +119,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextConfig.ropeGlobalBase } - return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil } type TextMLP struct { diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index 3a89afe72..89cc54b8b 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -8,7 +8,6 @@ import ( 
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -95,7 +94,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.ropeBaseLocal } - return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil } type TextScaledWordEmbedding struct { @@ -256,14 +255,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten query := attn.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) query = attn.QueryNorm.Forward(ctx, query, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, ropeBase) var key, value ml.Tensor if !sharedKV { key = attn.Key.Forward(ctx, hiddenStates) key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) key = attn.KeyNorm.Forward(ctx, key, opts.eps) - key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, ropeBase) value = attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) @@ -330,6 +329,10 @@ func (o *TextOptions) isLocal(i int) bool { return o.slidingWindowPattern[i] } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor, base float32) ml.Tensor { + return nn.RoPE(ctx, t, p, o.headDim(), base, 1./o.ropeScale, rope.WithTypeNeoX()) +} + func newTextModel(c fs.Config) *TextModel { return &TextModel{ TextLayers: make([]TextLayer, c.Uint("block_count")), diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index da08ed96f..9d1520bf3 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -9,7 +9,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -52,7 +51,7 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err } func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type Options struct { @@ -70,14 +69,14 @@ type Options struct { ropeScale float32 } -func (o Options) RoPEOptions() []func(*rope.Options) { - return []func(*rope.Options){ +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX(), rope.WithOriginalContextLength(o.originalContextLength), rope.WithExtrapolationFactor(1.), - // NOTE: ggml sets this implicitly so there's no need to set it here - // rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0), - } + // NOTE: ggml sets this implicitly so there's no need to set it here + // rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0), + ) } func (o Options) 
headDim() int { @@ -135,8 +134,8 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) } - query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) - key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 52c66ba57..5ff4894e4 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -20,6 +19,10 @@ type Options struct { eps, ropeBase, ropeScale float32 } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors)) +} + type Model struct { model.Base model.TextProcessor @@ -115,7 +118,6 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) - ropeDim := cmp.Or(opts.ropeDim, headDim) query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) @@ -126,8 +128,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -135,8 +137,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].SelfAttention.RopeFactors), nil } type MLP struct { diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 96b5d24d8..c2bf06148 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" 
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -33,8 +32,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) if useRope { - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) } if opts.useQKNorm { @@ -152,6 +151,10 @@ type TextOptions struct { attentionFloorScale float64 } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors)) +} + type TextModel struct { Layers []TextLayer `gguf:"blk"` @@ -236,5 +239,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].Attention.RopeFactors), nil } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 624d31510..ebb7b3aa1 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/model/input" ) @@ -20,6 +19,10 @@ type TextOptions struct { ropeScalingBeta float32 } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -42,11 +45,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -61,7 +64,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } type MLP struct { diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index d763df7a0..1de0412d5 100644 --- 
a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -16,8 +16,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { return x2.Scale(ctx, -1).Concat(ctx, x1, 0) } -func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { - return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor { + return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin)) } type VisionSelfAttention struct { @@ -36,8 +36,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) - query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) - key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) + key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 65f0a8278..afd674eb9 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -8,7 +8,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" ) @@ -26,11 +25,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -44,8 +43,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers - if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil + if layer, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { + return m.applyRotaryPositionEmbeddings(ctx, key, shift, layer.SelfAttention.RopeFactors), nil } return key, nil @@ -206,6 +205,10 @@ type TextModelOptions struct { crossAttentionLayers []int32 } +func (o TextModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors)) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` 
Transformer *TextDecoder `gguf:"blk"` diff --git a/model/models/nomicbert/model.go b/model/models/nomicbert/model.go index 0e742dfa1..2510240d6 100644 --- a/model/models/nomicbert/model.go +++ b/model/models/nomicbert/model.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" @@ -37,6 +36,10 @@ type Options struct { ropeFreqBase float32 } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim, o.ropeFreqBase, 1.0, rope.WithTypeNeoX()) +} + // Single Encoder Layer type EncoderLayer struct { *Attention @@ -105,8 +108,8 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml chunks := qkv.Chunk(ctx, 1, opts.numHeads) query, key, value := chunks[0], chunks[1], chunks[2] - query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX()) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil) diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 10a1e65cf..66f546ae6 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -10,7 +10,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -22,6 +21,10 @@ type Options struct { eps, ropeBase, ropeScale float32 } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX()) +} + type Attention struct { Query *nn.Linear `gguf:"attn_q"` Key *nn.Linear `gguf:"attn_k"` @@ -32,7 +35,6 @@ type Attention struct { func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenStates.Dim(1) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) - ropeDim := cmp.Or(opts.ropeDim, headDim) query := attn.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) @@ -43,8 +45,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, value := attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -123,8 +125,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } func (m 
Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } func New(c fs.Config) (model.Model, error) { diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index e6c6e6c19..b4db6043e 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -7,7 +7,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -18,6 +17,13 @@ type TextOptions struct { eps, ropeBase, ropeScale float32 } +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithTypeNeoX(), + ) +} + type TextModel struct { TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -60,11 +66,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -78,7 +84,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } // MLP implements the feed-forward network component with SwiGLU activation diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 5cbb01f7e..bfdafabe4 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -18,8 +18,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { return x2.Scale(ctx, -1).Concat(ctx, x1, 0) } -func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { - return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor { + return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin)) } func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor { @@ -67,8 +67,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m key = 
key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) - query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) - key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) + key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) // Scale factor for scaled dot-product attention scale := 1.0 / math.Sqrt(float64(opts.headDim)) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 483439ac4..d7747364e 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -9,7 +9,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -46,7 +45,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions rope.WithAttentionFactor(attnFactor), ) } - return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...) + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...) } type Attention struct { diff --git a/model/models/qwen3vl/model.go b/model/models/qwen3vl/model.go index 579863ae5..cb1ce8d2c 100644 --- a/model/models/qwen3vl/model.go +++ b/model/models/qwen3vl/model.go @@ -195,7 +195,7 @@ func New(c fs.Config) (model.Model, error) { m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) { m.positionCache = nil positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1) - return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil + return m.Options.applyRotaryPositionEmbeddings(ctx, key, positions), nil }) return &m, nil } diff --git a/model/models/qwen3vl/model_text.go b/model/models/qwen3vl/model_text.go index 64a567b02..750c2473a 100644 --- a/model/models/qwen3vl/model_text.go +++ b/model/models/qwen3vl/model_text.go @@ -10,7 +10,6 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/ml/nn/fast" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" ) @@ -35,8 +34,8 @@ func (o TextOptions) headDim() int { return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) } -func (o TextOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor { - return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))), +func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))), rope.WithInterleaveMRoPE(o.mropeSections), ) } @@ -64,8 +63,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens query = sa.QueryNorm.Forward(ctx, query, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps) - query = opts.applyRotaryPositionalEmbedding(ctx, query, positions) - key = opts.applyRotaryPositionalEmbedding(ctx, key, positions) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), 
batchSize) diff --git a/model/models/qwen3vl/model_vision.go b/model/models/qwen3vl/model_vision.go index b22ac305c..761281edc 100644 --- a/model/models/qwen3vl/model_vision.go +++ b/model/models/qwen3vl/model_vision.go @@ -23,18 +23,18 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { return x2.Scale(ctx, -1).Concat(ctx, x1, 0) } -func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { - return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor { + return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin)) } func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor { query := sa.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1)) - query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) + query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) key := sa.Key.Forward(ctx, hiddenStates) key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1)) - key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) value := sa.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
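One consequence of grouping the options with the call: models that need conditional or composed RoPE options (as qwen3 does above) still end up with exactly one nn.RoPE call site per model. The sketch below shows that shape using only option constructors that appear in this patch; it reuses the imports and Options type from the earlier sketch, and the originalContextLength field, headDim() helper, and the `> 0` guard are hypothetical stand-ins for whatever the real model checks.

// applyRotaryPositionEmbeddings builds the option slice first, then
// makes the single nn.RoPE call shared by Forward and Shift.
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	// Options shared by every call for this model.
	opts := []func(*rope.Options){rope.WithTypeNeoX()}

	// Hypothetical guard: only extend the slice when context-length
	// extension is configured, as the qwen3 change above does.
	if o.originalContextLength > 0 {
		opts = append(opts,
			rope.WithOriginalContextLength(o.originalContextLength),
			rope.WithExtrapolationFactor(1.),
		)
	}

	return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
}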