mirror of https://github.com/ollama/ollama
Merge 89637ae43b into a013693f80
commit 4005a3bb71
@@ -827,10 +827,6 @@ func (f GGML) SupportsFlashAttention() bool {
         return false
     }

-    if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
-        return false
-    }
-
     // Check head counts match and are non-zero
     headCountK := f.KV().EmbeddingHeadCountK()
     headCountV := f.KV().EmbeddingHeadCountV()

@@ -33,7 +33,7 @@ type Backend interface {

 // BackendCacheConfig should be implemented by backends that need special output
 // from the cache to meet specific requirements. It is frequently implemented in
-// conjunction with ScaledDotProductAttention.
+// conjunction with [nn.fastAttention].
 type BackendCacheConfig interface {
     CacheConfig() CacheConfig
 }

@@ -152,7 +152,6 @@ type Tensor interface {
     Div(ctx Context, t2 Tensor) Tensor

     Mulmat(ctx Context, t2 Tensor) Tensor
-    MulmatFullPrec(ctx Context, t2 Tensor) Tensor
     MulmatID(ctx Context, t2, ids Tensor) Tensor
     AddID(ctx Context, t2, ids Tensor) Tensor

@@ -213,32 +212,6 @@ type Tensor interface {
     Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
 }

-// ScaledDotProductAttention implements a fused attention
-// operation equivalent to following code on a tensor named
-// query:
-//
-// query = query.Permute(ctx, 0, 2, 1, 3)
-// key = key.Permute(ctx, 0, 2, 1, 3)
-// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-//
-// kq := key.MulmatFullPrec(ctx, query)
-//
-// kq = kq.Scale(ctx, scale)
-//
-// if mask != nil {
-// 	kq = kq.Add(ctx, mask)
-// }
-//
-// kq = kq.Softmax(ctx)
-//
-// kqv := value.Mulmat(ctx, kq)
-// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-//
-// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
-type ScaledDotProductAttention interface {
-    ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
-}
-
 type number interface {
     ~int | ~int8 | ~int16 | ~int32 | ~int64 |
         ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
@@ -19,6 +19,7 @@ import (
     "io"
     "log/slog"
     "maps"
+    "math"
     "os"
     "runtime"
     "slices"

@@ -35,6 +36,7 @@ import (
     "github.com/ollama/ollama/logutil"
     "github.com/ollama/ollama/ml"
     ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+    "github.com/ollama/ollama/ml/nn/attention"
    "github.com/ollama/ollama/ml/nn/rope"
     "golang.org/x/sync/errgroup"
 )

@@ -882,7 +884,7 @@ func shapeToGGML(shape []int) *C.int64_t {
     return &sh[0]
 }

-func pad(length, pad C.size_t) C.size_t {
+func pad[T C.size_t | int](length, pad T) T {
     return ((length + pad - 1) / pad) * pad
 }

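Note on the generic pad helper above: it rounds a length up to the next multiple of the padding value. The sketch below restates that arithmetic as a standalone Go program under the assumption of plain integer instantiations only (the C.size_t case from the diff is omitted so it can compile without cgo).

package main

import "fmt"

// pad rounds length up to the nearest multiple of align using the same
// ((length + pad - 1) / pad) * pad arithmetic as the generic backend helper.
func pad[T int | int64](length, align T) T {
	return ((length + align - 1) / align) * align
}

func main() {
	fmt.Println(pad(5, 32))  // 32
	fmt.Println(pad(64, 32)) // 64
	fmt.Println(pad(65, 32)) // 96
}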
@@ -1248,16 +1250,6 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     }
 }

-func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-    mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
-    C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
-
-    return &Tensor{
-        b: t.b,
-        t: mul,
-    }
-}
-
 func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
     return &Tensor{
         b: t.b,

@@ -1648,75 +1640,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
     }
 }

-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
-    // If the cache didn't help us with required transformations, do them here
-    if !cacheConfigApplied {
-        cacheConfig := t.b.CacheConfig()
-
-        // Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
-
-        if cacheConfig.PermutedV {
-            value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-        }
-
-        if mask != nil {
-            padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
-            if padSize > 0 {
-                mask = mask.Pad(ctx, 0, padSize, 0, 0)
-            }
-
-            if mask.DType() != cacheConfig.MaskDType {
-                mask = mask.Cast(ctx, cacheConfig.MaskDType)
-            }
-        }
-    }
-
-    var kqMask *C.struct_ggml_tensor
-    if mask != nil {
-        kqMask = mask.(*Tensor).t
-    }
-
-    query := t.Permute(ctx, 0, 2, 1, 3)
-    key = key.Permute(ctx, 0, 2, 1, 3)
-
-    if t.b.flashAttention == ml.FlashAttentionEnabled {
-        value = value.Permute(ctx, 0, 2, 1, 3)
-
-        kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
-        if sinks != nil {
-            C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
-        }
-        C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
-
-        if vmla != nil {
-            var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
-            cur = cur.Permute(ctx, 0, 2, 1, 3)
-            cur = vmla.Mulmat(ctx, cur)
-            cur = cur.Permute(ctx, 0, 2, 1, 3)
-            cur = cur.Contiguous(ctx)
-            kqv = cur.(*Tensor).t
-        }
-
-        return &Tensor{b: t.b, t: kqv}
-    } else {
-        kq := key.MulmatFullPrec(ctx, query)
-        kq = &Tensor{
-            b: t.b,
-            t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-        }
-        if sinks != nil {
-            C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t)
-        }
-
-        kqv := value.Mulmat(ctx, kq)
-        if vmla != nil {
-            kqv = vmla.Mulmat(ctx, kqv)
-        }
-
-        return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    }
-}
-
 func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
     return &Tensor{
         b: t.b,
@@ -1849,3 +1772,89 @@ func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Te
     }
     return s
 }
+
+func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attention.Options)) ml.Tensor {
+    opts := attention.Options{
+        Scale: 1 / math.Sqrt(float64(t.Dim(0))),
+    }
+
+    for _, fn := range fns {
+        fn(&opts)
+    }
+
+    if !opts.Cached {
+        config := t.b.CacheConfig()
+        if config.PermutedV {
+            value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+        }
+
+        if opts.Mask != nil {
+            if padSize := pad(opts.Mask.Dim(1), config.MaskBatchPadding) - opts.Mask.Dim(1); padSize > 0 {
+                opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0)
+            }
+
+            if opts.Mask.DType() != config.MaskDType {
+                opts.Mask = opts.Mask.Cast(ctx, config.MaskDType)
+            }
+        }
+    }
+
+    query := t.Permute(ctx, 0, 2, 1, 3)
+    key = key.Permute(ctx, 0, 2, 1, 3)
+
+    var mask *C.struct_ggml_tensor
+    if opts.Mask != nil {
+        mask = opts.Mask.(*Tensor).t
+    }
+
+    if t.b.flashAttention == ml.FlashAttentionEnabled {
+        value = value.Permute(ctx, 0, 2, 1, 3)
+
+        tt := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, mask, C.float(opts.Scale), 0, C.float(opts.LogitSoftcap))
+        C.ggml_flash_attn_ext_set_prec(tt, C.GGML_PREC_F32)
+        if opts.Sinks != nil {
+            C.ggml_flash_attn_ext_add_sinks(tt, opts.Sinks.(*Tensor).t)
+        }
+
+        var attention ml.Tensor = &Tensor{b: t.b, t: tt}
+        if opts.MLA != nil {
+            attention = attention.Permute(ctx, 0, 2, 1, 3)
+            attention = opts.MLA.Mulmat(ctx, attention)
+            attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+        }
+
+        return attention
+    }
+
+    scores := key.Mulmat(ctx, query)
+    C.ggml_mul_mat_set_prec(scores.(*Tensor).t, C.GGML_PREC_F32)
+    if opts.LogitSoftcap > 0 {
+        scores = scores.Scale(ctx, 1/float64(opts.LogitSoftcap)).Tanh(ctx).Scale(ctx, float64(opts.LogitSoftcap))
+    }
+
+    if opts.Cached {
+        scores = &Tensor{b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, mask, C.float(opts.Scale), 0)}
+    } else {
+        scores = scores.Scale(ctx, opts.Scale)
+        if opts.Mask != nil {
+            scores = scores.Add(ctx, opts.Mask)
+        }
+
+        scores = scores.Softmax(ctx)
+    }
+
+    if opts.Sinks != nil {
+        C.ggml_soft_max_add_sinks(scores.(*Tensor).t, opts.Sinks.(*Tensor).t)
+    }
+
+    if key.Dim(1) == value.Dim(2) && key.Dim(2) == value.Dim(1) {
+        value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+    }
+
+    attention := value.Mulmat(ctx, scores)
+    if opts.MLA != nil {
+        attention = opts.MLA.Mulmat(ctx, attention)
+    }
+
+    return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+}
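Note on the unfused branch above: when a logit softcap is set, it is applied before softmax as cap·tanh(scores/cap), which is what the Scale→Tanh→Scale chain computes. The standalone sketch below restates that transform on plain float32 values; the softcap function and its name are illustrative only, not part of the backend.

package main

import (
	"fmt"
	"math"
)

// softcap squashes each raw attention score into (-limit, limit):
// divide by limit, tanh, multiply back by limit.
func softcap(scores []float32, limit float32) []float32 {
	out := make([]float32, len(scores))
	for i, s := range scores {
		out[i] = limit * float32(math.Tanh(float64(s/limit)))
	}
	return out
}

func main() {
	// Large-magnitude scores saturate near ±50 instead of growing unbounded.
	fmt.Println(softcap([]float32{-200, -5, 0, 5, 200}, 50))
}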
@@ -1,12 +1,17 @@
 package nn

 import (
-    "fmt"
+    "log"

     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
+    "github.com/ollama/ollama/ml/nn/attention"
 )

+type fastAttention interface {
+    SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor
+}
+
 // Attention implements scaled dot-product attention for transformer models:
 // Attention(Q, K, V) = softmax(QK^T/√d_k)V
 //
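The unexported fastAttention interface added above is what nn.Attention asserts against before delegating to the backend's fused path. A hedged sketch of how that check could be expressed outside the nn package follows; the local interface simply restates the method set from the diff, and hasFastPath is a hypothetical helper, not part of this PR.

package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn/attention"
)

// fastAttention restates the unexported interface so the assertion can be
// written here; the method set mirrors the one declared in package nn.
type fastAttention interface {
	SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor
}

// hasFastPath reports whether a tensor's backend provides the fused SDPA path,
// mirroring the type assertion nn.Attention performs before it panics otherwise.
func hasFastPath(t ml.Tensor) bool {
	_, ok := t.(fastAttention)
	return ok
}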
@@ -21,27 +26,19 @@ import (
 // Returns:
 //
 // Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-    return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
-}
-
-func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-    return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
-}
-
-func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-    ctx.Forward(query)
+func Attention(ctx ml.Context, query, key, value ml.Tensor, cache kvcache.Cache, fns ...func(*attention.Options)) ml.Tensor {
     if key != nil && value != nil {
         if query.Dim(0) != key.Dim(0) {
-            panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+            log.Fatalf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))
         }

         if key.Dim(1) != value.Dim(1) {
-            panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
+            log.Fatalf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))
         }

         if key.Dim(2) != value.Dim(2) {
-            panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+            log.Fatalf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))
         }

         ctx.Forward(key, value)

@@ -57,28 +54,12 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
         key, value, mask = cache.Get(ctx)
     }

-    if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
-        cacheConfigApplied := cache != nil
-        return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
-    } else {
-        query = query.Permute(ctx, 0, 2, 1, 3)
-        key = key.Permute(ctx, 0, 2, 1, 3)
-        value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-        kq := key.MulmatFullPrec(ctx, query)
-
-        kq = kq.Scale(ctx, scale)
-        if mask != nil {
-            kq = kq.Add(ctx, mask)
-        }
-        kq = kq.Softmax(ctx)
-
-        kqv := value.Mulmat(ctx, kq)
-
-        if vmla != nil {
-            kqv = vmla.Mulmat(ctx, kqv)
-        }
-
-        return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    }
+    if t, ok := query.(fastAttention); ok {
+        return t.SDPA(ctx, key, value, append([]func(*attention.Options){
+            attention.WithMask(mask),
+            func(opts *attention.Options) { opts.Cached = cache != nil },
+        }, fns...)...)
+    }
+
+    panic("Attention not implemented for this tensor type")
 }
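With the scale folded into functional options, a model call site now passes the cache followed by any overrides; omitting WithScale keeps the 1/√d_k default computed in SDPA. Below is a hypothetical call site under those assumptions (identifiers such as headDim and sinks stand in for whatever a given model defines).

package example

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/attention"
)

// selfAttention shows the option-based call shape: scale and sinks are no
// longer positional arguments on nn.Attention / nn.AttentionWithSinks.
func selfAttention(ctx ml.Context, query, key, value, sinks ml.Tensor, headDim int, cache kvcache.Cache) ml.Tensor {
	return nn.Attention(ctx, query, key, value, cache,
		attention.WithScale(1/math.Sqrt(float64(headDim))), // explicit override of the 1/√d_k default
		attention.WithSinks(sinks),
	)
}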
@@ -0,0 +1,55 @@
+package attention
+
+import (
+    "github.com/ollama/ollama/ml"
+)
+
+type Options struct {
+    // Scale is a scaling factor applied to the attention scores. Default is 1/√d_k.
+    Scale float64
+
+    // LogitSoftcap is used to apply a soft cap to the logits before softmax.
+    LogitSoftcap float32
+
+    // Mask is used in some attention mechanisms to mask out certain positions.
+    Mask ml.Tensor
+
+    // Sinks is used in some attention mechanisms to store additional data.
+    Sinks ml.Tensor
+
+    // MLA is used in some attention mechanisms for multi-latent attention.
+    MLA ml.Tensor
+
+    // Cached indicates whether key/value were retrieved from cache.
+    Cached bool
+}
+
+func WithScale(scale float64) func(*Options) {
+    return func(o *Options) {
+        o.Scale = scale
+    }
+}
+
+func WithSinks(sinks ml.Tensor) func(*Options) {
+    return func(o *Options) {
+        o.Sinks = sinks
+    }
+}
+
+func WithMLA(mla ml.Tensor) func(*Options) {
+    return func(o *Options) {
+        o.MLA = mla
+    }
+}
+
+func WithMask(mask ml.Tensor) func(*Options) {
+    return func(o *Options) {
+        o.Mask = mask
+    }
+}
+
+func WithLogitSoftcap(softcap float32) func(*Options) {
+    return func(o *Options) {
+        o.LogitSoftcap = softcap
+    }
+}
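Each With* helper above returns a closure that mutates an Options value; callers and the backend's SDPA fold them the same way. A minimal sketch of that fold, assuming only the package introduced in this hunk (the buildOptions name is illustrative):

package example

import "github.com/ollama/ollama/ml/nn/attention"

// buildOptions applies a list of option closures to a zero-value Options,
// mirroring the for-loop at the top of the backend's SDPA implementation.
func buildOptions(fns ...func(*attention.Options)) attention.Options {
	var opts attention.Options
	for _, fn := range fns {
		fn(&opts)
	}
	return opts
}

// Example: buildOptions(attention.WithScale(0.125), attention.WithLogitSoftcap(50))
// yields Options{Scale: 0.125, LogitSoftcap: 50}.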
@@ -2,7 +2,6 @@ package bert

 import (
     "cmp"
-    "math"

     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/ml"

@@ -99,7 +98,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Option
     value := a.Value.Forward(ctx, hiddenStates)
     value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)

-    attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
     return a.Output.Forward(ctx, attention)
 }
@@ -10,6 +10,7 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/attention"
     "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model"
     "github.com/ollama/ollama/model/input"

@@ -66,22 +67,22 @@ type Attention struct {
     Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
 }

-func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (m *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     seqLength := hiddenStates.Dim(1)

     var query ml.Tensor
     if opts.qLoraRank == 0 {
-        query = attn.Q.Forward(ctx, hiddenStates)
+        query = m.Q.Forward(ctx, hiddenStates)
     } else {
-        query = attn.QA.Forward(ctx, hiddenStates)
-        query = attn.QANorm.Forward(ctx, query, opts.eps)
-        query = attn.QB.Forward(ctx, query)
+        query = m.QA.Forward(ctx, hiddenStates)
+        query = m.QANorm.Forward(ctx, query, opts.eps)
+        query = m.QB.Forward(ctx, query)
     }

     query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
     queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)

-    compressedKV := attn.KVA.Forward(ctx, hiddenStates)
+    compressedKV := m.KVA.Forward(ctx, hiddenStates)
     kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
     kRot := compressedKV.View(ctx,
         opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,

@@ -91,12 +92,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
     qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
     kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
-    kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
-
-    var attention ml.Tensor
+    kPass = m.KVANorm.Forward(ctx, kPass, opts.eps)

     if !opts.isMLA { // v3
-        kPass = attn.KVB.Forward(ctx, kPass)
+        kPass = m.KVB.Forward(ctx, kPass)

         kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
         kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)

@@ -104,10 +103,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
         kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
         query = qRot.Concat(ctx, queryChunks[0], 0)
         key := kRot.Concat(ctx, kvChunks[0], 0)
-        attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+        hiddenStates = nn.Attention(ctx, query, key, kvChunks[1], cache, attention.WithScale(opts.kqScale))
     } else { // v3.1
         qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
-        qPassAbsorb := attn.KB.Forward(ctx, qPass)
+        qPassAbsorb := m.KB.Forward(ctx, qPass)
         qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)

         query = qRot.Concat(ctx, qPassAbsorb, 0)

@@ -115,11 +114,14 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
         key := kRot.Concat(ctx, kPass, 0)
         value := kPass

-        attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
+        hiddenStates = nn.Attention(ctx, query, key, value, cache,
+            attention.WithMLA(m.VB.Weight),
+            attention.WithScale(opts.kqScale),
+        )
     }

-    attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
-    return attn.Output.Forward(ctx, attention)
+    hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), seqLength)
+    return m.Output.Forward(ctx, hiddenStates)
 }

 type MLP interface {
@@ -1,11 +1,11 @@
 package deepseekocr

 import (
-    "math"
     "slices"

     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/attention"
 )

 type samModel struct {

@@ -166,23 +166,13 @@ func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samO

     ctx.Forward(query, key, value)

-    query = query.Permute(ctx, 0, 2, 1, 3)
-    rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
+    rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w})
     mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
     mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)

-    key = key.Permute(ctx, 0, 2, 1, 3)
-    scores := key.MulmatFullPrec(ctx, query)
-    scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
-
-    scores = scores.Add(ctx, mask)
-    scores = scores.Softmax(ctx)
-
-    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-    attention := value.Mulmat(ctx, scores)
-    attention = attention.Permute(ctx, 0, 2, 1, 3)
-    attention = attention.Contiguous(ctx, -1, w, h, b)
-    return m.Output.Forward(ctx, attention)
+    hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
+    hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b)
+    return m.Output.Forward(ctx, hiddenStates)
 }

 type samMLP struct {
@@ -1,8 +1,6 @@
 package deepseekocr

 import (
-    "math"
-
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"

@@ -85,7 +83,7 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

-    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, -1, attention.Dim(2))
     return m.Output.Forward(ctx, attention)
 }

@@ -102,7 +102,7 @@ func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOption
     chunks := qkv.Chunk(ctx, 1, opts.numHeads)
     query, key, value := chunks[0], chunks[1], chunks[2]

-    attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
     return m.Output.Forward(ctx, attention)
 }
@@ -7,6 +7,7 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/attention"
     "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model"
     "github.com/ollama/ollama/model/input"

@@ -72,9 +73,10 @@ func New(c fs.Config) (model.Model, error) {
         },
     }

-    slidingWindowLen := int32(c.Uint("attention.sliding_window"))
-    m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-    m.Cache.SetConfig(ml.CacheConfig{})
+    m.Cache = kvcache.NewWrapperCache(
+        kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
+        kvcache.NewCausalCache(m.Shift),
+    )

     return &m, nil
 }
@@ -106,28 +108,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

-    cache.Put(ctx, k, v)
-    k, v, mask := cache.Get(ctx)
+    hiddenState = nn.Attention(ctx, q, k, v, cache,
+        attention.WithLogitSoftcap(opts.attnLogitSoftcap),
+        attention.WithScale(1),
+    )

-    q = q.Permute(ctx, 0, 2, 1, 3)
-    k = k.Permute(ctx, 0, 2, 1, 3)
-    v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-    kq := k.Mulmat(ctx, q)
-
-    // logit softcap
-    kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
-    kq = kq.Tanh(ctx)
-    kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
-
-    kq = kq.Add(ctx, mask)
-    kq = kq.Softmax(ctx)
-
-    kqv := v.Mulmat(ctx, kq)
-    kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
-
-    return sa.Output.Forward(ctx, kqv)
+    hiddenState = hiddenState.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
+    return sa.Output.Forward(ctx, hiddenState)
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
@@ -7,6 +7,7 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/attention"
     "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model/input"
 )

@@ -165,8 +166,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

-    scaleFactor := 1.0
-    kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+    kqv := nn.Attention(ctx, q, k, v, cache, attention.WithScale(1))
     kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

     return sa.Output.Forward(ctx, kqv)

@@ -1,8 +1,6 @@
 package gemma3

 import (
-    "math"
-
     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"

@@ -28,7 +26,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
     key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
     value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)

-    attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

     hiddenState = sa.Output.Forward(ctx, attention)
@@ -8,6 +8,7 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/attention"
     "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model/input"
 )

@@ -269,9 +270,9 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
         value = value.RMSNorm(ctx, nil, opts.eps)
     }

-    attention := nn.Attention(ctx, query, key, value, 1., cache)
-    attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
-    return attn.Output.Forward(ctx, attention)
+    hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithScale(1))
+    hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
+    return attn.Output.Forward(ctx, hiddenStates)
 }

 type TextMLP struct {
@@ -2,13 +2,13 @@ package gptoss

 import (
     "cmp"
-    "math"
     "strings"

     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/attention"
     "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model"
     "github.com/ollama/ollama/model/input"

@@ -137,9 +137,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

-    attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
-    attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
-    return attn.Output.Forward(ctx, attention).Add(ctx, residual)
+    hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithSinks(attn.Sinks))
+    hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
+    return attn.Output.Forward(ctx, hiddenStates).Add(ctx, residual)
 }

 type MLPBlock struct {
@@ -2,7 +2,6 @@ package llama

 import (
     "cmp"
-    "math"

     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/kvcache"

@@ -131,7 +130,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)

-    attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
     return sa.Output.Forward(ctx, attention)
 }

@@ -45,7 +45,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
         query = query.Mul(ctx, attentionScales)
     }

-    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
     return sa.Output.Forward(ctx, attention)
 }

@@ -72,7 +72,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tens
     query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
     key = applyVisionRotaryEmbedding(ctx, key, cos, sin)

-    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
     return sa.Output.Forward(ctx, attention)
 }

@@ -79,7 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
         q = q.Mul(ctx, positionsScale)
     }

-    kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
+    kqv := nn.Attention(ctx, q, k, v, cache)
     kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
     return sa.Output.Forward(ctx, kqv)
 }
@@ -39,7 +39,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
     query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
     key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)

-    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
     return sa.Output.Forward(ctx, attention)
 }

@@ -1,7 +1,6 @@
 package mllama

 import (
-    "math"
     "slices"

     "github.com/ollama/ollama/fs"

@@ -34,8 +33,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-    scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-    attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

     return sa.Output.Forward(ctx, attention)
@@ -122,20 +120,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
     }

     key, value, _ = cache.Get(ctx)
-
-    scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-
-    query = query.Permute(ctx, 0, 2, 1, 3)
-    key = key.Permute(ctx, 0, 2, 1, 3)
-    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-    kq := key.MulmatFullPrec(ctx, query)
-
-    kq = kq.Scale(ctx, scaleFactor)
-    kq = kq.Softmax(ctx)
-
-    kqv := value.Mulmat(ctx, kq)
-    attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

     return ca.Output.Forward(ctx, attention)

@@ -1,7 +1,6 @@
 package mllama

 import (
-    "math"
     "slices"

     "github.com/ollama/ollama/fs"
@@ -30,7 +29,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)

-    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
     return sa.Output.Forward(ctx, attention)
 }

@@ -2,7 +2,6 @@ package nomicbert

 import (
     "cmp"
-    "math"

     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/ml"

@@ -166,7 +165,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

-    attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)

     attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)


@@ -2,7 +2,6 @@ package olmo3

 import (
     "fmt"
-    "math"

     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/kvcache"

@@ -132,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)

-    attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, m.hiddenSize, batchSize)

     return sa.Output.Forward(ctx, attention)
@@ -3,7 +3,6 @@ package qwen2
 import (
     "cmp"
     "fmt"
-    "math"
     "strings"

     "github.com/ollama/ollama/fs"

@@ -48,7 +47,7 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

-    attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)

     return attn.Output.Forward(ctx, attention)

@@ -1,8 +1,6 @@
 package qwen25vl

 import (
-    "math"
-
     "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
@@ -81,8 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-    scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-    kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+    kqv := nn.Attention(ctx, q, k, v, cache)
     kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

     return sa.Output.Forward(ctx, kqv)

@@ -8,6 +8,7 @@ import (
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
     "github.com/ollama/ollama/ml/nn/rope"
+    "github.com/ollama/ollama/ml/nn/attention"
 )

 func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
@@ -50,25 +51,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions,
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

-    // Scale factor for scaled dot-product attention
-    scale := 1.0 / math.Sqrt(float64(opts.headDim))
-
-    // Scaled dot-product attention
-    query = query.Permute(ctx, 0, 2, 1, 3)
-    key = key.Permute(ctx, 0, 2, 1, 3)
-    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-    kq := key.MulmatFullPrec(ctx, query)
-    kq = kq.Scale(ctx, scale)
-    if mask != nil {
-        kq = kq.Add(ctx, mask)
-    }
-    kq = kq.Softmax(ctx)
-    kqv := value.Mulmat(ctx, kq)
-    attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
-
-    return sa.Output.Forward(ctx, attention)
+    hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
+    hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, hiddenStates.Dim(2))
+    return sa.Output.Forward(ctx, hiddenStates)
 }

 // VisionMLP implements the multi-layer perceptron
@@ -74,7 +74,7 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

-    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
     return sa.Output.Forward(ctx, attention)
 }

@@ -66,7 +66,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
     query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
     key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

-    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
+    attention := nn.Attention(ctx, query, key, value, cache)
     attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
     return sa.Output.Forward(ctx, attention)
 }

@@ -39,7 +39,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Ten
     value := sa.Value.Forward(ctx, hiddenStates)
     value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))

-    attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil)
+    attention := nn.Attention(ctx, query, key, value, nil)
     attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
     return sa.Output.Forward(ctx, attention)
 }