commit 4005a3bb71
Author: Michael Yang
Date: 2025-12-17 06:54:52 +02:00 (committed by GitHub)
30 changed files with 228 additions and 277 deletions


@ -827,10 +827,6 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()


@ -33,7 +33,7 @@ type Backend interface {
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
// conjunction with [nn.fastAttention].
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
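A minimal sketch of what a backend implementing this interface might return, using the CacheConfig fields referenced elsewhere in this diff (CachePadding, PermutedV, MaskBatchPadding, MaskDType); the concrete struct and the values shown are assumptions, not part of this change:

func (b *Backend) CacheConfig() ml.CacheConfig {
	return ml.CacheConfig{
		CachePadding:     256,         // pad cached K/V lengths for faster kernels (illustrative value)
		PermutedV:        true,        // store V pre-permuted so attention can skip a transpose
		MaskBatchPadding: 32,          // granularity to pad the mask's batch dimension to (illustrative)
		MaskDType:        ml.DTypeF16, // dtype the backend wants the mask cast to
	}
}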
@ -152,7 +152,6 @@ type Tensor interface {
Div(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
MulmatID(ctx Context, t2, ids Tensor) Tensor
AddID(ctx Context, t2, ids Tensor) Tensor
@ -213,32 +212,6 @@ type Tensor interface {
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
//
// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
}
type number interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |


@ -19,6 +19,7 @@ import (
"io"
"log/slog"
"maps"
"math"
"os"
"runtime"
"slices"
@ -35,6 +36,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"golang.org/x/sync/errgroup"
)
@ -882,7 +884,7 @@ func shapeToGGML(shape []int) *C.int64_t {
return &sh[0]
}
func pad(length, pad C.size_t) C.size_t {
func pad[T C.size_t | int](length, pad T) T {
return ((length + pad - 1) / pad) * pad
}
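pad now accepts plain ints as well as C.size_t because the new SDPA path below calls it with int dimensions. It rounds length up to the next multiple of pad; two illustrative spot checks:

pad(100, 32)                     // ((100+32-1)/32)*32 == 128
pad(C.size_t(100), C.size_t(32)) // same result for the existing C.size_t callers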
@ -1248,16 +1250,6 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
b: t.b,
t: mul,
}
}
func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@ -1648,75 +1640,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
// If the cache didn't help us with required transformations, do them here
if !cacheConfigApplied {
cacheConfig := t.b.CacheConfig()
// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
if cacheConfig.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
if mask != nil {
padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
if padSize > 0 {
mask = mask.Pad(ctx, 0, padSize, 0, 0)
}
if mask.DType() != cacheConfig.MaskDType {
mask = mask.Cast(ctx, cacheConfig.MaskDType)
}
}
}
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
if t.b.flashAttention == ml.FlashAttentionEnabled {
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
if sinks != nil {
C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
}
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
if vmla != nil {
var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = vmla.Mulmat(ctx, cur)
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = cur.Contiguous(ctx)
kqv = cur.(*Tensor).t
}
return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
if sinks != nil {
C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t)
}
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@ -1849,3 +1772,89 @@ func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Te
}
return s
}
func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attention.Options)) ml.Tensor {
opts := attention.Options{
Scale: 1 / math.Sqrt(float64(t.Dim(0))),
}
for _, fn := range fns {
fn(&opts)
}
if !opts.Cached {
config := t.b.CacheConfig()
if config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
if opts.Mask != nil {
if padSize := pad(opts.Mask.Dim(1), config.MaskBatchPadding) - opts.Mask.Dim(1); padSize > 0 {
opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0)
}
if opts.Mask.DType() != config.MaskDType {
opts.Mask = opts.Mask.Cast(ctx, config.MaskDType)
}
}
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
var mask *C.struct_ggml_tensor
if opts.Mask != nil {
mask = opts.Mask.(*Tensor).t
}
if t.b.flashAttention == ml.FlashAttentionEnabled {
value = value.Permute(ctx, 0, 2, 1, 3)
tt := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, mask, C.float(opts.Scale), 0, C.float(opts.LogitSoftcap))
C.ggml_flash_attn_ext_set_prec(tt, C.GGML_PREC_F32)
if opts.Sinks != nil {
C.ggml_flash_attn_ext_add_sinks(tt, opts.Sinks.(*Tensor).t)
}
var attention ml.Tensor = &Tensor{b: t.b, t: tt}
if opts.MLA != nil {
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = opts.MLA.Mulmat(ctx, attention)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
return attention
}
scores := key.Mulmat(ctx, query)
C.ggml_mul_mat_set_prec(scores.(*Tensor).t, C.GGML_PREC_F32)
if opts.LogitSoftcap > 0 {
scores = scores.Scale(ctx, 1/float64(opts.LogitSoftcap)).Tanh(ctx).Scale(ctx, float64(opts.LogitSoftcap))
}
if opts.Cached {
scores = &Tensor{b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, mask, C.float(opts.Scale), 0)}
} else {
scores = scores.Scale(ctx, opts.Scale)
if opts.Mask != nil {
scores = scores.Add(ctx, opts.Mask)
}
scores = scores.Softmax(ctx)
}
if opts.Sinks != nil {
C.ggml_soft_max_add_sinks(scores.(*Tensor).t, opts.Sinks.(*Tensor).t)
}
if key.Dim(1) == value.Dim(2) && key.Dim(2) == value.Dim(1) {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
attention := value.Mulmat(ctx, scores)
if opts.MLA != nil {
attention = opts.MLA.Mulmat(ctx, attention)
}
return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
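The doc comment deleted from the old ScaledDotProductAttention interface spelled out the unfused equivalent of this operation. For reference, a sketch of what the non-flash, non-cached path above computes in plain ml.Tensor ops, ignoring sinks, MLA, and the logit softcap (naiveSDPA is a hypothetical helper, not part of the commit; the real path also forces F32 precision on the first matmul):

func naiveSDPA(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	query = query.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	scores := key.Mulmat(ctx, query) // QK^T per head
	scores = scores.Scale(ctx, scale)
	if mask != nil {
		scores = scores.Add(ctx, mask)
	}
	scores = scores.Softmax(ctx)

	out := value.Mulmat(ctx, scores) // softmax(scale·QK^T)V
	return out.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}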


@ -1,12 +1,17 @@
package nn
import (
"fmt"
"log"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/attention"
)
type fastAttention interface {
SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor
}
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
@ -21,27 +26,19 @@ import (
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
func Attention(ctx ml.Context, query, key, value ml.Tensor, cache kvcache.Cache, fns ...func(*attention.Options)) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
log.Fatalf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
log.Fatalf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
log.Fatalf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))
}
ctx.Forward(key, value)
@ -57,28 +54,12 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
key, value, mask = cache.Get(ctx)
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
cacheConfigApplied := cache != nil
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
if t, ok := query.(fastAttention); ok {
return t.SDPA(ctx, key, value, append([]func(*attention.Options){
attention.WithMask(mask),
func(opts *attention.Options) { opts.Cached = cache != nil },
}, fns...)...)
}
panic("Attention not implemented for this tensor type")
}
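The model updates below all follow the same pattern: the positional scale argument and the AttentionWithSinks/AttentionWithVMLA entry points disappear in favour of options, with the scale defaulting to 1/√d_k. A sketch of a typical call site (query, key, value, cache, headDim, and sinks stand in for whatever the model defines):

// before
out := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(headDim)), cache)

// after: 1/√d_k is the default scale, so it is only spelled out when it differs,
// and sinks, MLA, and logit softcap become options instead of separate functions
out = nn.Attention(ctx, query, key, value, cache,
	attention.WithSinks(sinks), // only for models with attention sinks, e.g. gptoss below
)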


@ -0,0 +1,55 @@
package attention
import (
"github.com/ollama/ollama/ml"
)
type Options struct {
// Scale is a scaling factor applied to the attention scores. Default is 1/√d_k.
Scale float64
// LogitSoftcap is used to apply a soft cap to the logits before softmax.
LogitSoftcap float32
// Mask is used in some attention mechanisms to mask out certain positions.
Mask ml.Tensor
// Sinks optionally holds per-head attention sink logits that are folded into the softmax.
Sinks ml.Tensor
// MLA is used in some attention mechanisms for multi-latent attention.
MLA ml.Tensor
// Cached indicates whether key/value were retrieved from cache.
Cached bool
}
func WithScale(scale float64) func(*Options) {
return func(o *Options) {
o.Scale = scale
}
}
func WithSinks(sinks ml.Tensor) func(*Options) {
return func(o *Options) {
o.Sinks = sinks
}
}
func WithMLA(mla ml.Tensor) func(*Options) {
return func(o *Options) {
o.MLA = mla
}
}
func WithMask(mask ml.Tensor) func(*Options) {
return func(o *Options) {
o.Mask = mask
}
}
func WithLogitSoftcap(softcap float32) func(*Options) {
return func(o *Options) {
o.LogitSoftcap = softcap
}
}
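Since each option is just a func(*Options), the With* helpers compose with ad-hoc closures, which is how nn.Attention injects the mask and the Cached flag above. A small illustration of how SDPA resolves its configuration (headDim and the values are illustrative):

opts := attention.Options{Scale: 1 / math.Sqrt(float64(headDim))} // SDPA's default
for _, fn := range []func(*attention.Options){
	attention.WithScale(1),
	attention.WithLogitSoftcap(50),
	func(o *attention.Options) { o.Cached = true }, // ad-hoc option, same mechanism nn.Attention uses
} {
	fn(&opts)
}
// opts.Scale == 1, opts.LogitSoftcap == 50, opts.Cached == true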


@ -2,7 +2,6 @@ package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@ -99,7 +98,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Option
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}


@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@ -66,22 +67,22 @@ type Attention struct {
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
func (m *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
query = m.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
query = m.QA.Forward(ctx, hiddenStates)
query = m.QANorm.Forward(ctx, query, opts.eps)
query = m.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
compressedKV := m.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
@ -91,12 +92,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
var attention ml.Tensor
kPass = m.KVANorm.Forward(ctx, kPass, opts.eps)
if !opts.isMLA { // v3
kPass = attn.KVB.Forward(ctx, kPass)
kPass = m.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
@ -104,10 +103,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
hiddenStates = nn.Attention(ctx, query, key, kvChunks[1], cache, attention.WithScale(opts.kqScale))
} else { // v3.1
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass)
qPassAbsorb := m.KB.Forward(ctx, qPass)
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
query = qRot.Concat(ctx, qPassAbsorb, 0)
@ -115,11 +114,14 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
key := kRot.Concat(ctx, kPass, 0)
value := kPass
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
hiddenStates = nn.Attention(ctx, query, key, value, cache,
attention.WithMLA(m.VB.Weight),
attention.WithScale(opts.kqScale),
)
}
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), seqLength)
return m.Output.Forward(ctx, hiddenStates)
}
type MLP interface {


@ -1,11 +1,11 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
)
type samModel struct {
@ -166,23 +166,13 @@ func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samO
ctx.Forward(query, key, value)
query = query.Permute(ctx, 0, 2, 1, 3)
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w})
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = attention.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, hiddenStates)
}
type samMLP struct {


@ -1,8 +1,6 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@ -85,7 +83,7 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))
return m.Output.Forward(ctx, attention)
}


@ -102,7 +102,7 @@ func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOption
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
return m.Output.Forward(ctx, attention)
}


@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@ -72,9 +73,10 @@ func New(c fs.Config) (model.Model, error) {
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
m.Cache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
@ -106,28 +108,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
hiddenState = nn.Attention(ctx, q, k, v, cache,
attention.WithLogitSoftcap(opts.attnLogitSoftcap),
attention.WithScale(1),
)
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
hiddenState = hiddenState.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, hiddenState)
}
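The inline soft-capping removed here is what WithLogitSoftcap now triggers inside SDPA (see the LogitSoftcap branch in the ggml backend earlier in this diff), and WithScale(1) matches the removed code, which applied no extra scale at the softmax step. With scores and softcap standing in for kq and opts.attnLogitSoftcap, the capped logits are:

scores = scores.Scale(ctx, 1/float64(softcap)).Tanh(ctx).Scale(ctx, float64(softcap))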
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {


@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@ -165,8 +166,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv := nn.Attention(ctx, q, k, v, cache, attention.WithScale(1))
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)


@ -1,8 +1,6 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@ -28,7 +26,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)


@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@ -269,9 +270,9 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
value = value.RMSNorm(ctx, nil, opts.eps)
}
attention := nn.Attention(ctx, query, key, value, 1., cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithScale(1))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
return attn.Output.Forward(ctx, hiddenStates)
}
type TextMLP struct {


@ -2,13 +2,13 @@ package gptoss
import (
"cmp"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@ -137,9 +137,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention).Add(ctx, residual)
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithSinks(attn.Sinks))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
return attn.Output.Forward(ctx, hiddenStates).Add(ctx, residual)
}
type MLPBlock struct {


@ -2,7 +2,6 @@ package llama
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@ -131,7 +130,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
}


@ -45,7 +45,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
query = query.Mul(ctx, attentionScales)
}
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}


@ -72,7 +72,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tens
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
return sa.Output.Forward(ctx, attention)
}


@ -79,7 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
q = q.Mul(ctx, positionsScale)
}
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv := nn.Attention(ctx, q, k, v, cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}


@ -39,7 +39,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}


@ -1,7 +1,6 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
@ -34,8 +33,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
@ -122,20 +120,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
}
key, value, _ = cache.Get(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)


@ -1,7 +1,6 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
@ -30,7 +29,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}


@ -2,7 +2,6 @@ package nomicbert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@ -166,7 +165,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)


@ -2,7 +2,6 @@ package olmo3
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@ -132,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)


@ -3,7 +3,6 @@ package qwen2
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
@ -48,7 +47,7 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return attn.Output.Forward(ctx, attention)


@ -1,8 +1,6 @@
package qwen25vl
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@ -81,8 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv := nn.Attention(ctx, q, k, v, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)


@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/ml/nn/attention"
)
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
@ -50,25 +51,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Scaled dot-product attention
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, hiddenStates.Dim(2))
return sa.Output.Forward(ctx, hiddenStates)
}
// VisionMLP implements the multi-layer perceptron


@ -74,7 +74,7 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}


@ -66,7 +66,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}


@ -39,7 +39,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Ten
value := sa.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention)
}