mirror of https://github.com/ollama/ollama
refactor rope
Change to a flatter directory structure and group the options with the function. Update models to call RoPE in one place.
This commit is contained in:
parent e082d60a24
commit 603ceefaa6
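Sketch of the pattern this commit applies across the models in the diff below. The `exampleOptions` type is hypothetical and only for illustration; the `nn.RoPE` call shape and the `rope` option constructors are taken from the hunks that follow. Each model's options type gains a single `applyRotaryPositionEmbeddings` helper that bundles its RoPE configuration with the call, so attention `Forward` and cache `Shift` both go through one definition instead of repeating the old `fast.RoPE` argument list.

```go
package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
)

// exampleOptions is a hypothetical stand-in for a model's options struct.
type exampleOptions struct {
	ropeDim             int
	ropeBase, ropeScale float32
}

// applyRotaryPositionEmbeddings groups the model's RoPE options with the call itself,
// so every call site (attention Forward, cache Shift) shares one definition.
func (o exampleOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale,
		rope.WithTypeNeoX(),
	)
}
```

Call sites then collapse to `query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)`, which is the shape of the per-model edits in the hunks below.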
@@ -1,5 +1,4 @@
-// fast provides implementations of fast (fused) operations for increased performance.
-package fast
+package nn
 
 import (
 	"github.com/ollama/ollama/ml"
@@ -8,7 +7,7 @@ import (
 
 // fastRoPE is an interface for tensors that support fast rotary positional embedding.
 type fastRoPE interface {
-	RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
+	RoPE(ctx ml.Context, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
 }
 
 // RoPE applies rotary positional embedding to tensor `t`.
@@ -1,3 +1,4 @@
+// Package rope provides options for RoPE
 package rope
 
 import "github.com/ollama/ollama/ml"
@@ -10,7 +10,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -42,13 +41,12 @@ type Options struct {
 	kqScale float64
 }
 
-func (o Options) RoPEOptions() []func(*rope.Options) {
-	attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
-	return []func(*rope.Options){
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1./o.ropeScale,
 		rope.WithOriginalContextLength(o.originalContextLength),
 		rope.WithExtrapolationFactor(1.),
-		rope.WithAttentionFactor(attnFactor),
-	}
+		rope.WithAttentionFactor(float32(1.0/(1.0+0.1*math.Log(float64(o.ropeScale))))),
+	)
 }
 
 type Attention struct {
@@ -91,8 +89,8 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 		compressedKV.Stride(1), compressedKV.Dim(1),
 	)
 
-	qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
-	kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+	qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
+	kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
 	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
 
 	var attention ml.Tensor
@@ -327,7 +325,7 @@ func New(c fs.Config) (model.Model, error) {
 }
 
 func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
@@ -6,7 +6,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 )
 
@@ -20,7 +19,7 @@ type textModel struct {
 }
 
 func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
+	return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 type textOptions struct {
@@ -38,8 +37,8 @@ func (o textOptions) headDim() int {
 	return o.hiddenSize / o.numHeads
 }
 
-func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
-	return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
+func (o textOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
 }
 
 type textBlock struct {
@@ -83,8 +82,8 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
 	value := m.Value.Forward(ctx, hiddenStates)
 	value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
 
-	query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
-	key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, -1, attention.Dim(2))
@@ -7,7 +7,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -22,6 +21,10 @@ type Options struct {
 	largeModelScaling bool
 }
 
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.attnKeyLen, o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX())
+}
+
 type Model struct {
 	model.Base
 	model.SentencePiece
@@ -88,7 +91,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +101,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -128,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 type MLP struct {
@@ -7,7 +7,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
@@ -20,6 +19,10 @@ type TextConfig struct {
 	largeModelScaling bool
 }
 
+func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX())
+}
+
 type TextModel struct {
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers []TextLayer `gguf:"blk"`
@@ -87,7 +90,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +101,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -116,7 +119,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
 		ropeBase = m.TextConfig.ropeGlobalBase
 	}
 
-	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
 }
 
 type TextMLP struct {
@@ -8,7 +8,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
@@ -95,7 +94,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
 		ropeBase = m.ropeBaseLocal
 	}
 
-	return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
 }
 
 type TextScaledWordEmbedding struct {
@@ -256,14 +255,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
 	query := attn.Query.Forward(ctx, hiddenStates)
 	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
 	query = attn.QueryNorm.Forward(ctx, query, opts.eps)
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, ropeBase)
 
 	var key, value ml.Tensor
 	if !sharedKV {
 		key = attn.Key.Forward(ctx, hiddenStates)
 		key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
 		key = attn.KeyNorm.Forward(ctx, key, opts.eps)
-		key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+		key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, ropeBase)
 
 		value = attn.Value.Forward(ctx, hiddenStates)
 		value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@@ -330,6 +329,10 @@ func (o *TextOptions) isLocal(i int) bool {
 	return o.slidingWindowPattern[i]
 }
 
+func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor, base float32) ml.Tensor {
+	return nn.RoPE(ctx, t, p, o.headDim(), base, 1./o.ropeScale, rope.WithTypeNeoX())
+}
+
 func newTextModel(c fs.Config) *TextModel {
 	return &TextModel{
 		TextLayers: make([]TextLayer, c.Uint("block_count")),
@@ -9,7 +9,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -52,7 +51,7 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
 }
 
 func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 type Options struct {
@@ -70,14 +69,14 @@ type Options struct {
 	ropeScale float32
 }
 
-func (o Options) RoPEOptions() []func(*rope.Options) {
-	return []func(*rope.Options){
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale,
 		rope.WithTypeNeoX(),
 		rope.WithOriginalContextLength(o.originalContextLength),
 		rope.WithExtrapolationFactor(1.),
 		// NOTE: ggml sets this implicitly so there's no need to set it here
 		// rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
-	}
+	)
 }
 
 func (o Options) headDim() int {
@@ -135,8 +134,8 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
 		value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
 	}
 
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
-	key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -8,7 +8,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -20,6 +19,10 @@ type Options struct {
 	eps, ropeBase, ropeScale float32
 }
 
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
+}
+
 type Model struct {
 	model.Base
 	model.TextProcessor
@@ -115,7 +118,6 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
-	ropeDim := cmp.Or(opts.ropeDim, headDim)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
@@ -126,8 +128,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
-	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
 
 	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
 	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -135,8 +137,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
-	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].SelfAttention.RopeFactors), nil
 }
 
 type MLP struct {
@@ -8,7 +8,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
@@ -33,8 +32,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
 	if useRope {
-		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
-		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+		query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
+		key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
 	}
 
 	if opts.useQKNorm {
@@ -152,6 +151,10 @@ type TextOptions struct {
 	attentionFloorScale float64
 }
 
+func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
+}
+
 type TextModel struct {
 	Layers []TextLayer `gguf:"blk"`
 
@@ -236,5 +239,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].Attention.RopeFactors), nil
 }
@@ -8,7 +8,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -20,6 +19,10 @@ type TextOptions struct {
 	ropeScalingBeta float32
 }
 
+func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale)
+}
+
 type TextModel struct {
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers []Layer `gguf:"blk"`
@@ -42,11 +45,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
+	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
+	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -61,7 +64,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 type MLP struct {
@@ -16,8 +16,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
 	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 
-func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
-	return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
+func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
+	return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
 }
 
 type VisionSelfAttention struct {
@@ -36,8 +36,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
 	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
 
-	query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
-	key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
+	query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
+	key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
@@ -8,7 +8,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 )
 
@@ -26,11 +25,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -44,8 +43,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
-	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
+	if layer, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
+		return m.applyRotaryPositionEmbeddings(ctx, key, shift, layer.SelfAttention.RopeFactors), nil
 	}
 
 	return key, nil
@@ -206,6 +205,10 @@ type TextModelOptions struct {
 	crossAttentionLayers []int32
 }
 
+func (o TextModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
+}
+
 type TextModel struct {
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Transformer *TextDecoder `gguf:"blk"`
|
@ -7,7 +7,6 @@ import (
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/ml/nn/fast"
|
|
||||||
"github.com/ollama/ollama/ml/nn/pooling"
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
"github.com/ollama/ollama/ml/nn/rope"
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
|
|
@ -37,6 +36,10 @@ type Options struct {
|
||||||
ropeFreqBase float32
|
ropeFreqBase float32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||||
|
return nn.RoPE(ctx, states, positions, o.headDim, o.ropeFreqBase, 1.0, rope.WithTypeNeoX())
|
||||||
|
}
|
||||||
|
|
||||||
// Single Encoder Layer
|
// Single Encoder Layer
|
||||||
type EncoderLayer struct {
|
type EncoderLayer struct {
|
||||||
*Attention
|
*Attention
|
||||||
|
|
@ -105,8 +108,8 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
|
||||||
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
|
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
|
||||||
query, key, value := chunks[0], chunks[1], chunks[2]
|
query, key, value := chunks[0], chunks[1], chunks[2]
|
||||||
|
|
||||||
query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
|
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||||
key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
|
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||||
|
|
||||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
|
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -10,7 +10,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -22,6 +21,10 @@ type Options struct {
 	eps, ropeBase, ropeScale float32
 }
 
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX())
+}
+
 type Attention struct {
 	Query *nn.Linear `gguf:"attn_q"`
 	Key *nn.Linear `gguf:"attn_k"`
@@ -32,7 +35,6 @@ type Attention struct {
 func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenStates.Dim(1)
 	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
-	ropeDim := cmp.Or(opts.ropeDim, headDim)
 
 	query := attn.Query.Forward(ctx, hiddenStates)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
@@ -43,8 +45,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
 	value := attn.Value.Forward(ctx, hiddenStates)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
-	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
 	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -123,8 +125,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 }
 
 func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
-	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 func New(c fs.Config) (model.Model, error) {
@@ -7,7 +7,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
@@ -18,6 +17,13 @@ type TextOptions struct {
 	eps, ropeBase, ropeScale float32
 }
 
+func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale,
+		rope.WithOriginalContextLength(o.originalContextLength),
+		rope.WithTypeNeoX(),
+	)
+}
+
 type TextModel struct {
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers []Layer `gguf:"blk"`
@@ -60,11 +66,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -78,7 +84,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation
@@ -18,8 +18,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
 	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 
-func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
-	return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
+func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
+	return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
 }
 
 func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
@@ -67,8 +67,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
 	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
 	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
 
-	query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
-	key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
+	query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
+	key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
 
 	// Scale factor for scaled dot-product attention
 	scale := 1.0 / math.Sqrt(float64(opts.headDim))
@@ -9,7 +9,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -46,7 +45,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
 			rope.WithAttentionFactor(attnFactor),
 		)
 	}
-	return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
+	return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
 }
 
 type Attention struct {
@@ -195,7 +195,7 @@ func New(c fs.Config) (model.Model, error) {
 	m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
 		m.positionCache = nil
 		positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1)
-		return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil
+		return m.Options.applyRotaryPositionEmbeddings(ctx, key, positions), nil
 	})
 	return &m, nil
 }
@@ -10,7 +10,6 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 )
@@ -35,8 +34,8 @@ func (o TextOptions) headDim() int {
 	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
 }
 
-func (o TextOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
-	return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))),
+func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))),
 		rope.WithInterleaveMRoPE(o.mropeSections),
 	)
 }
@@ -64,8 +63,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
 	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
 	key = sa.KeyNorm.Forward(ctx, key, opts.eps)
 
-	query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
-	key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
||||||
|
|
@ -23,18 +23,18 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
|
||||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
|
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
|
||||||
query := sa.Query.Forward(ctx, hiddenStates)
|
query := sa.Query.Forward(ctx, hiddenStates)
|
||||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
|
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
|
||||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
||||||
|
|
||||||
key := sa.Key.Forward(ctx, hiddenStates)
|
key := sa.Key.Forward(ctx, hiddenStates)
|
||||||
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
|
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
|
||||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
||||||
|
|
||||||
value := sa.Value.Forward(ctx, hiddenStates)
|
value := sa.Value.Forward(ctx, hiddenStates)
|
||||||
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
|
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
|
||||||
|
|
|
||||||