diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 44a48511c..ea525df9d 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -827,10 +827,6 @@ func (f GGML) SupportsFlashAttention() bool { return false } - if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) { - return false - } - // Check head counts match and are non-zero headCountK := f.KV().EmbeddingHeadCountK() headCountV := f.KV().EmbeddingHeadCountV() diff --git a/ml/backend.go b/ml/backend.go index 1e781fa7f..620f29d81 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -33,7 +33,7 @@ type Backend interface { // BackendCacheConfig should be implemented by backends that need special output // from the cache to meet specific requirements. It is frequently implemented in -// conjunction with ScaledDotProductAttention. +// conjunction with [nn.fastAttention]. type BackendCacheConfig interface { CacheConfig() CacheConfig } @@ -152,7 +152,6 @@ type Tensor interface { Div(ctx Context, t2 Tensor) Tensor Mulmat(ctx Context, t2 Tensor) Tensor - MulmatFullPrec(ctx Context, t2 Tensor) Tensor MulmatID(ctx Context, t2, ids Tensor) Tensor AddID(ctx Context, t2, ids Tensor) Tensor @@ -213,32 +212,6 @@ type Tensor interface { Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor } -// ScaledDotProductAttention implements a fused attention -// operation equivalent to following code on a tensor named -// query: -// -// query = query.Permute(ctx, 0, 2, 1, 3) -// key = key.Permute(ctx, 0, 2, 1, 3) -// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) -// -// kq := key.MulmatFullPrec(ctx, query) -// -// kq = kq.Scale(ctx, scale) -// -// if mask != nil { -// kq = kq.Add(ctx, mask) -// } -// -// kq = kq.Softmax(ctx) -// -// kqv := value.Mulmat(ctx, kq) -// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) -// -// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed -type ScaledDotProductAttention interface { - ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor -} - type number interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 6a044260a..98f8dfbfd 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -19,6 +19,7 @@ import ( "io" "log/slog" "maps" + "math" "os" "runtime" "slices" @@ -35,6 +36,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "golang.org/x/sync/errgroup" ) @@ -882,7 +884,7 @@ func shapeToGGML(shape []int) *C.int64_t { return &sh[0] } -func pad(length, pad C.size_t) C.size_t { +func pad[T C.size_t | int](length, pad T) T { return ((length + pad - 1) / pad) * pad } @@ -1248,16 +1250,6 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { } } -func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { - mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t) - C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32) - - return &Tensor{ - b: t.b, - t: mul, - } -} - func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor { return &Tensor{ b: t.b, @@ -1648,75 +1640,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { } } -func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks 
ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor { - // If the cache didn't help us with required transformations, do them here - if !cacheConfigApplied { - cacheConfig := t.b.CacheConfig() - - // Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller - - if cacheConfig.PermutedV { - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - } - - if mask != nil { - padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1) - if padSize > 0 { - mask = mask.Pad(ctx, 0, padSize, 0, 0) - } - - if mask.DType() != cacheConfig.MaskDType { - mask = mask.Cast(ctx, cacheConfig.MaskDType) - } - } - } - - var kqMask *C.struct_ggml_tensor - if mask != nil { - kqMask = mask.(*Tensor).t - } - - query := t.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) - - if t.b.flashAttention == ml.FlashAttentionEnabled { - value = value.Permute(ctx, 0, 2, 1, 3) - - kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) - if sinks != nil { - C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t) - } - C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32) - - if vmla != nil { - var cur ml.Tensor = &Tensor{b: t.b, t: kqv} - cur = cur.Permute(ctx, 0, 2, 1, 3) - cur = vmla.Mulmat(ctx, cur) - cur = cur.Permute(ctx, 0, 2, 1, 3) - cur = cur.Contiguous(ctx) - kqv = cur.(*Tensor).t - } - - return &Tensor{b: t.b, t: kqv} - } else { - kq := key.MulmatFullPrec(ctx, query) - kq = &Tensor{ - b: t.b, - t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), - } - if sinks != nil { - C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t) - } - - kqv := value.Mulmat(ctx, kq) - if vmla != nil { - kqv = vmla.Mulmat(ctx, kqv) - } - - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - } -} - func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, @@ -1849,3 +1772,89 @@ func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Te } return s } + +func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attention.Options)) ml.Tensor { + opts := attention.Options{ + Scale: 1 / math.Sqrt(float64(t.Dim(0))), + } + + for _, fn := range fns { + fn(&opts) + } + + if !opts.Cached { + config := t.b.CacheConfig() + if config.PermutedV { + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + } + + if opts.Mask != nil { + if padSize := pad(opts.Mask.Dim(1), config.MaskBatchPadding) - opts.Mask.Dim(1); padSize > 0 { + opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0) + } + + if opts.Mask.DType() != config.MaskDType { + opts.Mask = opts.Mask.Cast(ctx, config.MaskDType) + } + } + } + + query := t.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + + var mask *C.struct_ggml_tensor + if opts.Mask != nil { + mask = opts.Mask.(*Tensor).t + } + + if t.b.flashAttention == ml.FlashAttentionEnabled { + value = value.Permute(ctx, 0, 2, 1, 3) + + tt := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, mask, C.float(opts.Scale), 0, C.float(opts.LogitSoftcap)) + C.ggml_flash_attn_ext_set_prec(tt, C.GGML_PREC_F32) + if opts.Sinks != nil { + C.ggml_flash_attn_ext_add_sinks(tt, opts.Sinks.(*Tensor).t) + } + + var attention ml.Tensor = &Tensor{b: t.b, t: tt} + if opts.MLA != nil { + attention = attention.Permute(ctx, 0, 2, 1, 3) + attention = opts.MLA.Mulmat(ctx, attention) + 
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + } + + return attention + } + + scores := key.Mulmat(ctx, query) + C.ggml_mul_mat_set_prec(scores.(*Tensor).t, C.GGML_PREC_F32) + if opts.LogitSoftcap > 0 { + scores = scores.Scale(ctx, 1/float64(opts.LogitSoftcap)).Tanh(ctx).Scale(ctx, float64(opts.LogitSoftcap)) + } + + if opts.Cached { + scores = &Tensor{b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, mask, C.float(opts.Scale), 0)} + } else { + scores = scores.Scale(ctx, opts.Scale) + if opts.Mask != nil { + scores = scores.Add(ctx, opts.Mask) + } + + scores = scores.Softmax(ctx) + } + + if opts.Sinks != nil { + C.ggml_soft_max_add_sinks(scores.(*Tensor).t, opts.Sinks.(*Tensor).t) + } + + if key.Dim(1) == value.Dim(2) && key.Dim(2) == value.Dim(1) { + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + } + + attention := value.Mulmat(ctx, scores) + if opts.MLA != nil { + attention = opts.MLA.Mulmat(ctx, attention) + } + + return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) +} diff --git a/ml/nn/attention.go b/ml/nn/attention.go index e495e1f60..5a3bfc092 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -1,12 +1,17 @@ package nn import ( - "fmt" + "log" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/attention" ) +type fastAttention interface { + SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor +} + // Attention implements scaled dot-product attention for transformer models: // Attention(Q, K, V) = softmax(QK^T/√d_k)V // @@ -21,27 +26,19 @@ import ( // Returns: // // Attention output with shape [d_v, heads, seq_len_q] -func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache) -} -func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache) -} - -func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - ctx.Forward(query) +func Attention(ctx ml.Context, query, key, value ml.Tensor, cache kvcache.Cache, fns ...func(*attention.Options)) ml.Tensor { if key != nil && value != nil { if query.Dim(0) != key.Dim(0) { - panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + log.Fatalf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)) } if key.Dim(1) != value.Dim(1) { - panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) + log.Fatalf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)) } if key.Dim(2) != value.Dim(2) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + log.Fatalf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)) } ctx.Forward(key, value) @@ -57,28 +54,12 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla key, value, mask = cache.Get(ctx) } - if sdpa, ok := query.(ml.ScaledDotProductAttention); ok { - cacheConfigApplied := cache != nil - return sdpa.ScaledDotProductAttention(ctx, key, value, 
mask, sinks, vmla, scale, cacheConfigApplied) - } else { - query = query.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := key.MulmatFullPrec(ctx, query) - - kq = kq.Scale(ctx, scale) - if mask != nil { - kq = kq.Add(ctx, mask) - } - kq = kq.Softmax(ctx) - - kqv := value.Mulmat(ctx, kq) - - if vmla != nil { - kqv = vmla.Mulmat(ctx, kqv) - } - - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + if t, ok := query.(fastAttention); ok { + return t.SDPA(ctx, key, value, append([]func(*attention.Options){ + attention.WithMask(mask), + func(opts *attention.Options) { opts.Cached = cache != nil }, + }, fns...)...) } + + panic("Attention not implemented for this tensor type") } diff --git a/ml/nn/attention/options.go b/ml/nn/attention/options.go new file mode 100644 index 000000000..4bf1c377c --- /dev/null +++ b/ml/nn/attention/options.go @@ -0,0 +1,55 @@ +package attention + +import ( + "github.com/ollama/ollama/ml" +) + +type Options struct { + // Scale is a scaling factor applied to the attention scores. Default is 1/√d_k. + Scale float64 + + // LogitSoftcap is used to apply a soft cap to the logits before softmax. + LogitSoftcap float32 + + // Mask is used in some attention mechanisms to mask out certain positions. + Mask ml.Tensor + + // Sinks is used in some attention mechanisms to store additional data. + Sinks ml.Tensor + + // MLA is used in some attention mechanisms for multi-latent attention. + MLA ml.Tensor + + // Cached indicates whether key/value were retrieved from cache. + Cached bool +} + +func WithScale(scale float64) func(*Options) { + return func(o *Options) { + o.Scale = scale + } +} + +func WithSinks(sinks ml.Tensor) func(*Options) { + return func(o *Options) { + o.Sinks = sinks + } +} + +func WithMLA(mla ml.Tensor) func(*Options) { + return func(o *Options) { + o.MLA = mla + } +} + +func WithMask(mask ml.Tensor) func(*Options) { + return func(o *Options) { + o.Mask = mask + } +} + +func WithLogitSoftcap(softcap float32) func(*Options) { + return func(o *Options) { + o.LogitSoftcap = softcap + } +} diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go index 79cb3a3c7..33c5d86c8 100644 --- a/model/models/bert/embed.go +++ b/model/models/bert/embed.go @@ -2,7 +2,6 @@ package bert import ( "cmp" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" @@ -99,7 +98,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Option value := a.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) - attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return a.Output.Forward(ctx, attention) } diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go index 576076aab..d40df204f 100644 --- a/model/models/deepseek2/model.go +++ b/model/models/deepseek2/model.go @@ -10,6 +10,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -66,22 +67,22 @@ type Attention struct { Output *nn.Linear `gguf:"attn_out,alt:attn_output"` } -func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, 
cache kvcache.Cache, opts *Options) ml.Tensor { +func (m *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { seqLength := hiddenStates.Dim(1) var query ml.Tensor if opts.qLoraRank == 0 { - query = attn.Q.Forward(ctx, hiddenStates) + query = m.Q.Forward(ctx, hiddenStates) } else { - query = attn.QA.Forward(ctx, hiddenStates) - query = attn.QANorm.Forward(ctx, query, opts.eps) - query = attn.QB.Forward(ctx, query) + query = m.QA.Forward(ctx, hiddenStates) + query = m.QANorm.Forward(ctx, query, opts.eps) + query = m.QB.Forward(ctx, query) } query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength) queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim) - compressedKV := attn.KVA.Forward(ctx, hiddenStates) + compressedKV := m.KVA.Forward(ctx, hiddenStates) kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1) kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim, @@ -91,12 +92,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions) kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions) - kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) - - var attention ml.Tensor + kPass = m.KVANorm.Forward(ctx, kPass, opts.eps) if !opts.isMLA { // v3 - kPass = attn.KVB.Forward(ctx, kPass) + kPass = m.KVB.Forward(ctx, kPass) kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim) @@ -104,10 +103,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1)) query = qRot.Concat(ctx, queryChunks[0], 0) key := kRot.Concat(ctx, kvChunks[0], 0) - attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache) + hiddenStates = nn.Attention(ctx, query, key, kvChunks[1], cache, attention.WithScale(opts.kqScale)) } else { // v3.1 qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3) - qPassAbsorb := attn.KB.Forward(ctx, qPass) + qPassAbsorb := m.KB.Forward(ctx, qPass) qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3) query = qRot.Concat(ctx, qPassAbsorb, 0) @@ -115,11 +114,14 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor key := kRot.Concat(ctx, kPass, 0) value := kPass - attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache) + hiddenStates = nn.Attention(ctx, query, key, value, cache, + attention.WithMLA(m.VB.Weight), + attention.WithScale(opts.kqScale), + ) } - attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) - return attn.Output.Forward(ctx, attention) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), seqLength) + return m.Output.Forward(ctx, hiddenStates) } type MLP interface { diff --git a/model/models/deepseekocr/model_sam.go b/model/models/deepseekocr/model_sam.go index 8bf30f96c..636743b72 100644 --- a/model/models/deepseekocr/model_sam.go +++ b/model/models/deepseekocr/model_sam.go @@ -1,11 +1,11 @@ package deepseekocr import ( - "math" "slices" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" ) type samModel struct { @@ -166,23 +166,13 @@ func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samO ctx.Forward(query, key, value) - query = 
query.Permute(ctx, 0, 2, 1, 3) - rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w}) + rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w}) mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw) mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b) - key = key.Permute(ctx, 0, 2, 1, 3) - scores := key.MulmatFullPrec(ctx, query) - scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim()))) - - scores = scores.Add(ctx, mask) - scores = scores.Softmax(ctx) - - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - attention := value.Mulmat(ctx, scores) - attention = attention.Permute(ctx, 0, 2, 1, 3) - attention = attention.Contiguous(ctx, -1, w, h, b) - return m.Output.Forward(ctx, attention) + hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask)) + hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b) + return m.Output.Forward(ctx, hiddenStates) } type samMLP struct { diff --git a/model/models/deepseekocr/model_text.go b/model/models/deepseekocr/model_text.go index ab6221ccf..f3aac476b 100644 --- a/model/models/deepseekocr/model_text.go +++ b/model/models/deepseekocr/model_text.go @@ -1,8 +1,6 @@ package deepseekocr import ( - "math" - "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -85,7 +83,7 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, -1, attention.Dim(2)) return m.Output.Forward(ctx, attention) } diff --git a/model/models/deepseekocr/model_vision.go b/model/models/deepseekocr/model_vision.go index 61121ebfd..ac7380dae 100644 --- a/model/models/deepseekocr/model_vision.go +++ b/model/models/deepseekocr/model_vision.go @@ -102,7 +102,7 @@ func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOption chunks := qkv.Chunk(ctx, 1, opts.numHeads) query, key, value := chunks[0], chunks[1], chunks[2] - attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3)) return m.Output.Forward(ctx, attention) } diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 7b0aa2f01..83ba6c719 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -7,6 +7,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -72,9 +73,10 @@ func New(c fs.Config) (model.Model, error) { }, } - slidingWindowLen := int32(c.Uint("attention.sliding_window")) - m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) - m.Cache.SetConfig(ml.CacheConfig{}) + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), + kvcache.NewCausalCache(m.Shift), + ) return &m, nil } @@ -106,28 +108,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := 
sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) - cache.Put(ctx, k, v) - k, v, mask := cache.Get(ctx) + hiddenState = nn.Attention(ctx, q, k, v, cache, + attention.WithLogitSoftcap(opts.attnLogitSoftcap), + attention.WithScale(1), + ) - q = q.Permute(ctx, 0, 2, 1, 3) - k = k.Permute(ctx, 0, 2, 1, 3) - v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := k.Mulmat(ctx, q) - - // logit softcap - kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap)) - kq = kq.Tanh(ctx) - kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap)) - - kq = kq.Add(ctx, mask) - kq = kq.Softmax(ctx) - - kqv := v.Mulmat(ctx, kq) - kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) - - return sa.Output.Forward(ctx, kqv) + hiddenState = hiddenState.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, hiddenState) } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index e1c0004d9..25b0a3d60 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -7,6 +7,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -165,8 +166,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) - scaleFactor := 1.0 - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv := nn.Attention(ctx, q, k, v, cache, attention.WithScale(1)) kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) return sa.Output.Forward(ctx, kqv) diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go index 8b1a8eb00..e99def5e9 100644 --- a/model/models/gemma3/model_vision.go +++ b/model/models/gemma3/model_vision.go @@ -1,8 +1,6 @@ package gemma3 import ( - "math" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -28,7 +26,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) hiddenState = sa.Output.Forward(ctx, attention) diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index 89cc54b8b..0e28fbc98 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -8,6 +8,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -269,9 +270,9 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten value = value.RMSNorm(ctx, nil, opts.eps) } - attention := nn.Attention(ctx, query, key, value, 1., cache) - attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) - return 
attn.Output.Forward(ctx, attention) + hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithScale(1)) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize) + return attn.Output.Forward(ctx, hiddenStates) } type TextMLP struct { diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 9d1520bf3..3e061a412 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -2,13 +2,13 @@ package gptoss import ( "cmp" - "math" "strings" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -137,9 +137,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache) - attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) - return attn.Output.Forward(ctx, attention).Add(ctx, residual) + hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithSinks(attn.Sinks)) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize) + return attn.Output.Forward(ctx, hiddenStates).Add(ctx, residual) } type MLPBlock struct { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 5ff4894e4..023545213 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -2,7 +2,6 @@ package llama import ( "cmp" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -131,7 +130,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index c2bf06148..692fd2396 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -45,7 +45,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent query = query.Mul(ctx, attentionScales) } - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go index ff6b7fcf2..0186b5db1 100644 --- a/model/models/llama4/model_vision.go +++ b/model/models/llama4/model_vision.go @@ -72,7 +72,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tens query = applyVisionRotaryEmbedding(ctx, query, cos, sin) key = applyVisionRotaryEmbedding(ctx, key, cos, sin) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil) + attention := nn.Attention(ctx, query, key, value, 
nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3)) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 36106107b..cb5b1c090 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -79,7 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit q = q.Mul(ctx, positionsScale) } - kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache) + kqv := nn.Attention(ctx, q, k, v, cache) kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) return sa.Output.Forward(ctx, kqv) } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 1de0412d5..8e291e9d0 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -39,7 +39,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index afd674eb9..fd6e2fb3f 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -1,7 +1,6 @@ package mllama import ( - "math" "slices" "github.com/ollama/ollama/fs" @@ -34,8 +33,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, scaleFactor, cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) @@ -122,20 +120,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio } key, value, _ = cache.Get(ctx) - - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - - query = query.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := key.MulmatFullPrec(ctx, query) - - kq = kq.Scale(ctx, scaleFactor) - kq = kq.Softmax(ctx) - - kqv := value.Mulmat(ctx, kq) - attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return ca.Output.Forward(ctx, attention) diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index 2d4249472..2c48ae870 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -1,7 +1,6 @@ package mllama import ( - "math" "slices" "github.com/ollama/ollama/fs" @@ -30,7 +29,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, 
opts.hiddenSize, attention.Dim(2), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/nomicbert/model.go b/model/models/nomicbert/model.go index 096d046a0..75c41de6f 100644 --- a/model/models/nomicbert/model.go +++ b/model/models/nomicbert/model.go @@ -2,7 +2,6 @@ package nomicbert import ( "cmp" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" @@ -166,7 +165,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) diff --git a/model/models/olmo3/model.go b/model/models/olmo3/model.go index 523c00e68..ba1287279 100644 --- a/model/models/olmo3/model.go +++ b/model/models/olmo3/model.go @@ -2,7 +2,6 @@ package olmo3 import ( "fmt" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -132,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, m.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 66f546ae6..36d5e26f7 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -3,7 +3,6 @@ package qwen2 import ( "cmp" "fmt" - "math" "strings" "github.com/ollama/ollama/fs" @@ -48,7 +47,7 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) return attn.Output.Forward(ctx, attention) diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 61b072d67..6f209d2a1 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -1,8 +1,6 @@ package qwen25vl import ( - "math" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" @@ -81,8 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv := nn.Attention(ctx, q, k, v, cache) kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, kqv) diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index f1275437f..b0f2eb54b 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -8,6 +8,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/ml/nn/attention" ) func blockDiagonalMask(ctx 
ml.Context, seqLength int, bounds []int) ml.Tensor { @@ -50,25 +51,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - // Scale factor for scaled dot-product attention - scale := 1.0 / math.Sqrt(float64(opts.headDim)) - - // Scaled dot-product attention - query = query.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := key.MulmatFullPrec(ctx, query) - kq = kq.Scale(ctx, scale) - if mask != nil { - kq = kq.Add(ctx, mask) - } - kq = kq.Softmax(ctx) - kqv := value.Mulmat(ctx, kq) - attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2)) - - return sa.Output.Forward(ctx, attention) + hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask)) + hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, hiddenStates.Dim(2)) + return sa.Output.Forward(ctx, hiddenStates) } // VisionMLP implements the multi-layer perceptron diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index d7747364e..c081aac7c 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -74,7 +74,7 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/qwen3vl/model_text.go b/model/models/qwen3vl/model_text.go index 750c2473a..8fc1f80bd 100644 --- a/model/models/qwen3vl/model_text.go +++ b/model/models/qwen3vl/model_text.go @@ -66,7 +66,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/qwen3vl/model_vision.go b/model/models/qwen3vl/model_vision.go index 761281edc..3ec2061b2 100644 --- a/model/models/qwen3vl/model_vision.go +++ b/model/models/qwen3vl/model_vision.go @@ -39,7 +39,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Ten value := sa.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1)) - attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2)) return sa.Output.Forward(ctx, attention) }
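
For reference, a minimal sketch of how a model's attention block calls the reworked nn.Attention with the new functional options from ml/nn/attention. The surrounding function, tensor names, and headDim parameter are placeholders for illustration; only the API shapes (nn.Attention's signature and the With* option constructors) come from this patch.

package example

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/attention"
)

// selfAttention is a hypothetical forward pass showing the options-based API.
// The default scale is 1/√d_k derived from query.Dim(0); WithScale overrides it,
// and WithSinks/WithLogitSoftcap/WithMLA are only passed when the model needs them.
func selfAttention(ctx ml.Context, query, key, value, sinks ml.Tensor, headDim int, cache kvcache.Cache) ml.Tensor {
	return nn.Attention(ctx, query, key, value, cache,
		attention.WithScale(1/math.Sqrt(float64(headDim))), // explicit scale (placeholder headDim)
		attention.WithSinks(sinks),                          // optional attention sinks, as in gptoss
	)
}

One note on the design choice visible in the diff: moving scale, sinks, logit softcap, and MLA into variadic options keeps nn.Attention's signature stable as backends grow features, which is why AttentionWithSinks and AttentionWithVMLA could be deleted and most call sites shrink to nn.Attention(ctx, q, k, v, cache).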