mirror of https://github.com/ollama/ollama
flash attn: add auto mode for llama engine (#13052)
* flash attn: add auto mode for llama engine
  If the user does not specify fa in the environment, use auto-mode.
* review comments
* ensure kv cache quantized types have FA explicitly enabled
  additional review comments
This commit is contained in:
parent 3af5d3b738
commit bd6c1d6b49
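For orientation, here is a minimal standalone sketch of the behaviour the message describes: when the user does not set OLLAMA_FLASH_ATTENTION, the new tri-state resolves to auto; otherwise it follows the user's boolean. This is an assumption for illustration, not code from this commit — the helper name flashAttentionFromEnv and the direct environment lookup are hypothetical, and the real server goes through envconfig.FlashAttention instead.

package main

import (
    "fmt"
    "os"
    "strconv"
)

// FlashAttentionType mirrors ml.FlashAttentionType introduced in this commit.
type FlashAttentionType int32

const (
    FlashAttentionAuto     FlashAttentionType = -1
    FlashAttentionDisabled FlashAttentionType = 0
    FlashAttentionEnabled  FlashAttentionType = 1
)

// flashAttentionFromEnv is a hypothetical helper: unset means auto,
// a set value forces flash attention on or off.
func flashAttentionFromEnv() FlashAttentionType {
    v, ok := os.LookupEnv("OLLAMA_FLASH_ATTENTION")
    if !ok || v == "" {
        return FlashAttentionAuto
    }
    if on, err := strconv.ParseBool(v); err == nil && on {
        return FlashAttentionEnabled
    }
    return FlashAttentionDisabled
}

func main() {
    // Prints -1 (auto), 0 (disabled) or 1 (enabled).
    fmt.Println("flash attention mode:", flashAttentionFromEnv())
}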
@@ -13,6 +13,7 @@ import (
     "github.com/ollama/ollama/format"
     "github.com/ollama/ollama/fs/util/bufioutil"
+    "github.com/ollama/ollama/ml"
 )

 type GGML struct {

@@ -550,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
     }, nil
 }

-func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
     context *= uint64(numParallel)

     embedding := f.KV().EmbeddingLength()

@@ -791,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
     }

     partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
-    if useFlashAttention {
+    if useFlashAttention == ml.FlashAttentionEnabled {
         // rough estimate of graph size with flash attention on
         partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
     }

@@ -809,6 +810,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
     return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
 }

+// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
+func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
+    if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
+        return false
+    }
+    return true
+}
+
 // SupportsFlashAttention checks if the model supports flash attention
 func (f GGML) SupportsFlashAttention() bool {
     _, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
@@ -118,7 +118,7 @@ type ContextParams struct {
     c C.struct_llama_context_params
 }

-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams {
     params := C.llama_context_default_params()
     params.n_ctx = C.uint(numCtx)
     params.n_batch = C.uint(batchSize * numSeqMax)

@@ -127,10 +127,13 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
     params.n_threads = C.int(threads)
     params.n_threads_batch = params.n_threads
     params.embeddings = C.bool(true)
-    if flashAttention {
-        params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
-    } else {
-        params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
+    switch flashAttention {
+    case ml.FlashAttentionEnabled:
+        params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED)
+    case ml.FlashAttentionDisabled:
+        params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED)
+    case ml.FlashAttentionAuto:
+        params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO)
     }
     params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
     params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
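One aside on the mapping above (an observation, not from the diff): the constants added in ml/device.go are noted as aligned with llama_flash_attn_type, so a direct conversion such as

    params.flash_attn_type = int32(flashAttention) // assumes the Go and C enums stay aligned

would in principle also work; the explicit switch avoids relying on the two enums staying in sync.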
@@ -188,6 +188,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
     if len(projectors) > 0 && llamaModel != nil {
         loadRequest.ProjectorPath = projectors[0]
     }
+    // Determine if the user has forced FA on or off
+    faUserSet := false
+    if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
+        faUserSet = true
+    }

     fa := envconfig.FlashAttention(f.FlashAttention())

@@ -205,19 +210,51 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st

     kvct := strings.ToLower(envconfig.KvCacheType())

-    if fa {
-        slog.Info("enabling flash attention")
-        loadRequest.FlashAttention = true
-
-        // Flash Attention also supports kv cache quantization
-        // Enable if the requested and kv cache type is supported by the model
-        if f.SupportsKVCacheType(kvct) {
-            loadRequest.KvCacheType = kvct
-        } else {
-            slog.Warn("kv cache type not supported by model", "type", kvct)
-        }
-    } else if kvct != "" && kvct != "f16" {
-        slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
-    }
+    if textProcessor == nil {
+        flashAttention := ml.FlashAttentionAuto
+        if faUserSet {
+            if fa {
+                flashAttention = ml.FlashAttentionEnabled
+            } else {
+                flashAttention = ml.FlashAttentionDisabled
+            }
+        }
+
+        if kvct != "" {
+            if f.KVCacheTypeIsQuantized(kvct) {
+                if flashAttention != ml.FlashAttentionEnabled {
+                    slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
+                    loadRequest.KvCacheType = ""
+                } else if f.SupportsKVCacheType(kvct) {
+                    loadRequest.KvCacheType = kvct
+                } else {
+                    slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
+                }
+            } else {
+                if f.SupportsKVCacheType(kvct) {
+                    loadRequest.KvCacheType = kvct
+                } else {
+                    slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
+                }
+            }
+        }
+        loadRequest.FlashAttention = flashAttention
+    } else {
+        // For Ollama engine, use our SupportsFlashAttention logic
+        if fa {
+            slog.Info("enabling flash attention")
+            loadRequest.FlashAttention = ml.FlashAttentionEnabled
+
+            // Flash Attention also supports kv cache quantization
+            // Enable if the requested and kv cache type is supported by the model
+            if f.SupportsKVCacheType(kvct) {
+                loadRequest.KvCacheType = kvct
+            } else {
+                slog.Warn("kv cache type not supported by model", "type", kvct)
+            }
+        } else if kvct != "" && kvct != "f16" {
+            slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
+        }
+    }

     gpuLibs := ml.LibraryPaths(gpus)
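The nested checks above can be hard to scan, so here is a hypothetical distillation of the llama-engine KV-cache rule as a pure function. The names acceptKVCacheType and modelSupports are assumptions standing in for the GGML methods; this is a sketch of the rule, not the shipped code: a quantized cache type is honored only when flash attention is explicitly enabled and the model supports that type, otherwise the request falls back to the default cache.

package main

import "fmt"

// FlashAttentionType mirrors ml.FlashAttentionType from this commit.
type FlashAttentionType int32

const (
    FlashAttentionAuto     FlashAttentionType = -1
    FlashAttentionDisabled FlashAttentionType = 0
    FlashAttentionEnabled  FlashAttentionType = 1
)

// kvCacheTypeIsQuantized mirrors GGML.KVCacheTypeIsQuantized above:
// anything other than "", f16, f32 or bf16 counts as quantized.
func kvCacheTypeIsQuantized(t string) bool {
    switch t {
    case "", "f16", "f32", "bf16":
        return false
    default:
        return true
    }
}

// acceptKVCacheType distills the branch above: quantized cache types need
// flash attention explicitly enabled, and any requested type must be
// supported by the model (modelSupports stands in for GGML.SupportsKVCacheType).
func acceptKVCacheType(kvct string, fa FlashAttentionType, modelSupports func(string) bool) string {
    if kvct == "" {
        return ""
    }
    if kvCacheTypeIsQuantized(kvct) && fa != FlashAttentionEnabled {
        return "" // warn: quantized cache requires flash attention to be enabled
    }
    if !modelSupports(kvct) {
        return "" // warn: unsupported OLLAMA_KV_CACHE_TYPE
    }
    return kvct
}

func main() {
    supports := func(t string) bool { return t == "q8_0" || t == "q4_0" }
    fmt.Println(acceptKVCacheType("q8_0", FlashAttentionAuto, supports) == "") // true: FA not forced on
    fmt.Println(acceptKVCacheType("q8_0", FlashAttentionEnabled, supports))    // q8_0
}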
@@ -435,7 +472,7 @@ type LoadRequest struct {
     LoraPath []string
     Parallel int
     BatchSize int
-    FlashAttention bool
+    FlashAttention ml.FlashAttentionType
     KvSize int
     KvCacheType string
     NumThreads int

@@ -74,7 +74,7 @@ type BackendParams struct {
     GPULayers GPULayersList

     // FlashAttention indicates that we should use a fused flash attention kernel
-    FlashAttention bool
+    FlashAttention FlashAttentionType
 }

 var backends = make(map[string]func(string, BackendParams) (Backend, error))

@@ -109,7 +109,7 @@ type Backend struct {
     // btDeviceMemory maps from a buffer type to the memory allocations associated with that device
     btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory

-    flashAttention bool
+    flashAttention ml.FlashAttentionType

     // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
     maxGraphNodes int

@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }

 func (b *Backend) CacheConfig() ml.CacheConfig {
-    if b.flashAttention {
+    if b.flashAttention == ml.FlashAttentionEnabled {
         return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
     } else {
         return ml.CacheConfig{CachePadding: 256, PermutedV: true}

@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
     query := t.Permute(ctx, 0, 2, 1, 3)
     key = key.Permute(ctx, 0, 2, 1, 3)

-    if t.b.flashAttention {
+    if t.b.flashAttention == ml.FlashAttentionEnabled {
         value = value.Permute(ctx, 0, 2, 1, 3)

         kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
ml/device.go (26 changed lines)

@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
     return true
 }

+type FlashAttentionType int32
+
+const (
+    // Aligned with llama_flash_attn_type
+    FlashAttentionAuto     FlashAttentionType = -1
+    FlashAttentionDisabled FlashAttentionType = 0
+    FlashAttentionEnabled  FlashAttentionType = 1
+)
+
+func (f FlashAttentionType) LogValue() slog.Value {
+    return slog.AnyValue(f.String())
+}
+
+func (f FlashAttentionType) String() string {
+    switch f {
+    case FlashAttentionAuto:
+        return "Auto"
+    case FlashAttentionDisabled:
+        return "Disabled"
+    case FlashAttentionEnabled:
+        return "Enabled"
+    default:
+        return "unknown"
+    }
+}
+
 // Given the list of GPUs this instantiation is targeted for,
 // figure out the visible devices environment variables
 // Set mustFilter true to enable filtering of CUDA devices
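A brief usage sketch (an assumption, not part of this diff): because the new type implements both fmt.Stringer and slog.LogValuer, the mode shows up as a readable word in structured logs rather than a raw number.

package main

import (
    "log/slog"

    "github.com/ollama/ollama/ml"
)

func main() {
    // Emits something like: level=INFO msg="loading model" flash_attention=Auto
    slog.Info("loading model", "flash_attention", ml.FlashAttentionAuto)
}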
@@ -26,6 +26,7 @@ import (
     "github.com/ollama/ollama/llama"
     "github.com/ollama/ollama/llm"
     "github.com/ollama/ollama/logutil"
+    "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/runner/common"
 )

@@ -832,7 +833,7 @@ func (s *Server) loadModel(
     ppath string,
     kvSize int,
     kvCacheType string,
-    flashAttention bool,
+    flashAttention ml.FlashAttentionType,
     threads int,
     multiUserCache bool,
 ) {