mirror of https://github.com/ollama/ollama

cleanup

commit baae175ebe
parent de82b1f9a3
@@ -33,7 +33,7 @@ type Backend interface {
 
 // BackendCacheConfig should be implemented by backends that need special output
 // from the cache to meet specific requirements. It is frequently implemented in
-// conjunction with ScaledDotProductAttention.
+// conjunction with [nn.fastAttention].
 type BackendCacheConfig interface {
 	CacheConfig() CacheConfig
 }
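
A minimal, self-contained sketch (not part of this commit) of what a backend opting into BackendCacheConfig can look like. The field names CachePadding, PermutedV, MaskDType, and MaskBatchPadding come from the ggml implementation removed further down in this diff; the concrete Go types, values, and the toyBackend/DTypeF16 names are stand-ins.

package main

import "fmt"

// DType and CacheConfig are stand-ins for the ml package types referenced by the diff.
type DType int

const DTypeF16 DType = iota // hypothetical dtype constant

type CacheConfig struct {
	CachePadding     int   // pad cached sequence length to a multiple of this (optimization only)
	PermutedV        bool  // store V pre-permuted so attention can skip a transpose
	MaskDType        DType // dtype the attention mask should be cast to
	MaskBatchPadding int   // pad the mask's dim 1 to a multiple of this
}

// toyBackend advertises its cache requirements by implementing CacheConfig(),
// mirroring the BackendCacheConfig interface in the hunk above.
type toyBackend struct{}

func (toyBackend) CacheConfig() CacheConfig {
	return CacheConfig{CachePadding: 32, PermutedV: true, MaskDType: DTypeF16, MaskBatchPadding: 32}
}

func main() {
	fmt.Printf("%+v\n", toyBackend{}.CacheConfig())
}
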
@@ -152,7 +152,6 @@ type Tensor interface {
 	Div(ctx Context, t2 Tensor) Tensor
 
 	Mulmat(ctx Context, t2 Tensor) Tensor
-	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
 	MulmatID(ctx Context, t2, ids Tensor) Tensor
 	AddID(ctx Context, t2, ids Tensor) Tensor
 
@@ -213,32 +212,6 @@ type Tensor interface {
 	Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
 }
 
-// ScaledDotProductAttention implements a fused attention
-// operation equivalent to following code on a tensor named
-// query:
-//
-//	query = query.Permute(ctx, 0, 2, 1, 3)
-//	key = key.Permute(ctx, 0, 2, 1, 3)
-//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-//
-//	kq := key.MulmatFullPrec(ctx, query)
-//
-//	kq = kq.Scale(ctx, scale)
-//
-//	if mask != nil {
-//		kq = kq.Add(ctx, mask)
-//	}
-//
-//	kq = kq.Softmax(ctx)
-//
-//	kqv := value.Mulmat(ctx, kq)
-//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-//
-// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
-type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
-}
-
 type number interface {
 	~int | ~int8 | ~int16 | ~int32 | ~int64 |
 		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
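
The doc comment deleted above already spells out the unfused tensor-op sequence the fused kernel replaces. As a purely numeric, self-contained sketch of what that sequence computes for a single head (ignoring the permutes, mask, sinks, and vmla handling done by the ggml code), scaled dot-product attention is softmax(scale · Q·Kᵀ)·V; this is illustration only, not repository code.

package main

import (
	"fmt"
	"math"
)

// attention computes softmax(scale * Q K^T) V row by row for one head.
// q is n×d, k is m×d, v is m×dv; the result is n×dv.
func attention(q, k, v [][]float64, scale float64) [][]float64 {
	out := make([][]float64, len(q))
	for i := range q {
		// scores[j] = scale * <q_i, k_j>
		scores := make([]float64, len(k))
		maxS := math.Inf(-1)
		for j := range k {
			for d := range q[i] {
				scores[j] += q[i][d] * k[j][d]
			}
			scores[j] *= scale
			maxS = math.Max(maxS, scores[j])
		}
		// numerically stable softmax over the row
		var sum float64
		for j := range scores {
			scores[j] = math.Exp(scores[j] - maxS)
			sum += scores[j]
		}
		// weighted sum of the value rows
		out[i] = make([]float64, len(v[0]))
		for j := range v {
			w := scores[j] / sum
			for d := range v[j] {
				out[i][d] += w * v[j][d]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}, {1, 1}}
	v := [][]float64{{1, 2}, {3, 4}, {5, 6}}
	fmt.Println(attention(q, k, v, 1/math.Sqrt(2))) // 2×2 output
}
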
@@ -1250,16 +1250,6 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }
-
-func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
-	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
-
-	return &Tensor{
-		b: t.b,
-		t: mul,
-	}
-}
 
 func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1650,75 +1640,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
-
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
-	// If the cache didn't help us with required transformations, do them here
-	if !cacheConfigApplied {
-		cacheConfig := t.b.CacheConfig()
-
-		// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
-
-		if cacheConfig.PermutedV {
-			value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-		}
-
-		if mask != nil {
-			padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
-			if padSize > 0 {
-				mask = mask.Pad(ctx, 0, padSize, 0, 0)
-			}
-
-			if mask.DType() != cacheConfig.MaskDType {
-				mask = mask.Cast(ctx, cacheConfig.MaskDType)
-			}
-		}
-	}
-
-	var kqMask *C.struct_ggml_tensor
-	if mask != nil {
-		kqMask = mask.(*Tensor).t
-	}
-
-	query := t.Permute(ctx, 0, 2, 1, 3)
-	key = key.Permute(ctx, 0, 2, 1, 3)
-
-	if t.b.flashAttention == ml.FlashAttentionEnabled {
-		value = value.Permute(ctx, 0, 2, 1, 3)
-
-		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
-		if sinks != nil {
-			C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
-		}
-		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
-
-		if vmla != nil {
-			var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
-			cur = cur.Permute(ctx, 0, 2, 1, 3)
-			cur = vmla.Mulmat(ctx, cur)
-			cur = cur.Permute(ctx, 0, 2, 1, 3)
-			cur = cur.Contiguous(ctx)
-			kqv = cur.(*Tensor).t
-		}
-
-		return &Tensor{b: t.b, t: kqv}
-	} else {
-		kq := key.MulmatFullPrec(ctx, query)
-		kq = &Tensor{
-			b: t.b,
-			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-		}
-		if sinks != nil {
-			C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t)
-		}
-
-		kqv := value.Mulmat(ctx, kq)
-		if vmla != nil {
-			kqv = vmla.Mulmat(ctx, kqv)
-		}
-
-		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	}
-}
 
 func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
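
One detail of the removed mask handling worth spelling out: the mask is grown along dim 1 to a multiple of MaskBatchPadding before being handed to the kernel. A minimal sketch of that padSize arithmetic, under the assumption that the (unshown) pad helper in ggml.go rounds up to the next multiple:

package main

import "fmt"

// pad is an assumed round-up-to-multiple helper; the real pad in ggml.go is not
// part of this diff.
func pad(length, multiple int) int {
	return (length + multiple - 1) / multiple * multiple
}

func main() {
	maskDim1, maskBatchPadding := 37, 32
	padSize := pad(maskDim1, maskBatchPadding) - maskDim1
	fmt.Println(padSize) // 27: mask.Dim(1) grows from 37 to 64 before the kernel runs
}
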