From bd6c1d6b49aca86dbb1a59182b293c0d1f7b8db8 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 12 Dec 2025 13:27:19 -0800 Subject: [PATCH] flash attn: add auto mode for llama engine (#13052) * flash attn: add auto mode for llama engine If the user does not specify fa in the environment, use auto-mode. * review comments * ensure kv cache quantized types have FA explicitly enabled additional review comments --- fs/ggml/ggml.go | 13 ++++++-- llama/llama.go | 13 +++++--- llm/server.go | 63 ++++++++++++++++++++++++++++-------- ml/backend.go | 2 +- ml/backend/ggml/ggml.go | 6 ++-- ml/device.go | 26 +++++++++++++++ runner/llamarunner/runner.go | 3 +- 7 files changed, 101 insertions(+), 25 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 4004bbfd9..691ea32b4 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -13,6 +13,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/util/bufioutil" + "github.com/ollama/ollama/ml" ) type GGML struct { @@ -550,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { }, nil } -func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) { +func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) { context *= uint64(numParallel) embedding := f.KV().EmbeddingLength() @@ -791,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri } partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6 - if useFlashAttention { + if useFlashAttention == ml.FlashAttentionEnabled { // rough estimate of graph size with flash attention on partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte } @@ -809,6 +810,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool { return slices.Contains([]string{"q8_0", "q4_0"}, cacheType) } +// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type +func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool { + if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" { + return false + } + return true +} + // SupportsFlashAttention checks if the model supports flash attention func (f GGML) SupportsFlashAttention() bool { _, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())] diff --git a/llama/llama.go b/llama/llama.go index 70bf3b9c3..49b3f56a6 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -118,7 +118,7 @@ type ContextParams struct { c C.struct_llama_context_params } -func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams { +func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams { params := C.llama_context_default_params() params.n_ctx = C.uint(numCtx) params.n_batch = C.uint(batchSize * numSeqMax) @@ -127,10 +127,13 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla params.n_threads = C.int(threads) params.n_threads_batch = params.n_threads params.embeddings = C.bool(true) - if flashAttention { - params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED - } else { - params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED + switch flashAttention { + case ml.FlashAttentionEnabled: 
+ params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED) + case ml.FlashAttentionDisabled: + params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED) + case ml.FlashAttentionAuto: + params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO) } params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) diff --git a/llm/server.go b/llm/server.go index abf6035dd..49af4e1b3 100644 --- a/llm/server.go +++ b/llm/server.go @@ -188,6 +188,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st if len(projectors) > 0 && llamaModel != nil { loadRequest.ProjectorPath = projectors[0] } + // Determine if the user has forced FA on or off + faUserSet := false + if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) { + faUserSet = true + } fa := envconfig.FlashAttention(f.FlashAttention()) @@ -205,19 +210,51 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st kvct := strings.ToLower(envconfig.KvCacheType()) - if fa { - slog.Info("enabling flash attention") - loadRequest.FlashAttention = true - - // Flash Attention also supports kv cache quantization - // Enable if the requested and kv cache type is supported by the model - if f.SupportsKVCacheType(kvct) { - loadRequest.KvCacheType = kvct - } else { - slog.Warn("kv cache type not supported by model", "type", kvct) + if textProcessor == nil { + flashAttention := ml.FlashAttentionAuto + if faUserSet { + if fa { + flashAttention = ml.FlashAttentionEnabled + } else { + flashAttention = ml.FlashAttentionDisabled + } + } + + if kvct != "" { + if f.KVCacheTypeIsQuantized(kvct) { + if flashAttention != ml.FlashAttentionEnabled { + slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct) + loadRequest.KvCacheType = "" + } else if f.SupportsKVCacheType(kvct) { + loadRequest.KvCacheType = kvct + } else { + slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct) + } + } else { + if f.SupportsKVCacheType(kvct) { + loadRequest.KvCacheType = kvct + } else { + slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct) + } + } + } + loadRequest.FlashAttention = flashAttention + } else { + // For Ollama engine, use our SupportsFlashAttention logic + if fa { + slog.Info("enabling flash attention") + loadRequest.FlashAttention = ml.FlashAttentionEnabled + + // Flash Attention also supports kv cache quantization + // Enable if the requested and kv cache type is supported by the model + if f.SupportsKVCacheType(kvct) { + loadRequest.KvCacheType = kvct + } else { + slog.Warn("kv cache type not supported by model", "type", kvct) + } + } else if kvct != "" && kvct != "f16" { + slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct) } - } else if kvct != "" && kvct != "f16" { - slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct) } gpuLibs := ml.LibraryPaths(gpus) @@ -435,7 +472,7 @@ type LoadRequest struct { LoraPath []string Parallel int BatchSize int - FlashAttention bool + FlashAttention ml.FlashAttentionType KvSize int KvCacheType string NumThreads int diff --git a/ml/backend.go b/ml/backend.go index 6e5a059ad..1e781fa7f 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -74,7 +74,7 @@ type BackendParams struct { GPULayers GPULayersList // FlashAttention indicates that we should use a fused flash attention kernel - FlashAttention bool + FlashAttention FlashAttentionType } var backends = 
make(map[string]func(string, BackendParams) (Backend, error)) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 18bdc91eb..a50d8ec9c 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -109,7 +109,7 @@ type Backend struct { // btDeviceMemory maps from a buffer type to the memory allocations associated with that device btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory - flashAttention bool + flashAttention ml.FlashAttentionType // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler maxGraphNodes int @@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context { } func (b *Backend) CacheConfig() ml.CacheConfig { - if b.flashAttention { + if b.flashAttention == ml.FlashAttentionEnabled { return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD} } else { return ml.CacheConfig{CachePadding: 256, PermutedV: true} @@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin query := t.Permute(ctx, 0, 2, 1, 3) key = key.Permute(ctx, 0, 2, 1, 3) - if t.b.flashAttention { + if t.b.flashAttention == ml.FlashAttentionEnabled { value = value.Permute(ctx, 0, 2, 1, 3) kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) diff --git a/ml/device.go b/ml/device.go index f892b512d..47e180d30 100644 --- a/ml/device.go +++ b/ml/device.go @@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool { return true } +type FlashAttentionType int32 + +const ( + // Aligned with llama_flash_attn_type + FlashAttentionAuto FlashAttentionType = -1 + FlashAttentionDisabled FlashAttentionType = 0 + FlashAttentionEnabled FlashAttentionType = 1 +) + +func (f FlashAttentionType) LogValue() slog.Value { + return slog.AnyValue(f.String()) +} + +func (f FlashAttentionType) String() string { + switch f { + case FlashAttentionAuto: + return "Auto" + case FlashAttentionDisabled: + return "Disabled" + case FlashAttentionEnabled: + return "Enabled" + default: + return "unknown" + } +} + // Given the list of GPUs this instantiation is targeted for, // figure out the visible devices environment variables // Set mustFilter true to enable filtering of CUDA devices diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index cb4bbe505..de9d718b3 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -26,6 +26,7 @@ import ( "github.com/ollama/ollama/llama" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/runner/common" ) @@ -832,7 +833,7 @@ func (s *Server) loadModel( ppath string, kvSize int, kvCacheType string, - flashAttention bool, + flashAttention ml.FlashAttentionType, threads int, multiUserCache bool, ) {
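
Reviewer sketch (not part of the patch): the selection logic added to llm/server.go boils down to three inputs — whether the user explicitly set OLLAMA_FLASH_ATTENTION, which value they set, and whether the requested OLLAMA_KV_CACHE_TYPE is quantized. The standalone Go program below mirrors that decision for illustration only; resolveFlashAttention and the direct environment parsing are hypothetical simplifications of envconfig.FlashAttention and the faUserSet check above, not code from this change.

// Illustrative sketch only — not part of this change; names and env handling are hypothetical.
package main

import (
	"fmt"
	"os"
	"strconv"
)

// FlashAttentionType mirrors the constants added in ml/device.go, which are
// aligned with llama_flash_attn_type in llama.cpp.
type FlashAttentionType int32

const (
	FlashAttentionAuto     FlashAttentionType = -1
	FlashAttentionDisabled FlashAttentionType = 0
	FlashAttentionEnabled  FlashAttentionType = 1
)

func (f FlashAttentionType) String() string {
	switch f {
	case FlashAttentionAuto:
		return "Auto"
	case FlashAttentionDisabled:
		return "Disabled"
	case FlashAttentionEnabled:
		return "Enabled"
	default:
		return "unknown"
	}
}

// resolveFlashAttention sketches the llama-engine path: an explicit user
// setting forces the mode, otherwise the runner is left in auto mode, and a
// quantized KV cache type is only honored when flash attention is forced on.
func resolveFlashAttention(kvCacheType string) (FlashAttentionType, string) {
	mode := FlashAttentionAuto
	if v, ok := os.LookupEnv("OLLAMA_FLASH_ATTENTION"); ok {
		if enabled, err := strconv.ParseBool(v); err == nil {
			if enabled {
				mode = FlashAttentionEnabled
			} else {
				mode = FlashAttentionDisabled
			}
		}
	}

	quantized := kvCacheType != "" && kvCacheType != "f16" &&
		kvCacheType != "f32" && kvCacheType != "bf16"
	if quantized && mode != FlashAttentionEnabled {
		// Same guard as the patch: drop the quantized cache type
		// unless OLLAMA_FLASH_ATTENTION was explicitly enabled.
		kvCacheType = ""
	}
	return mode, kvCacheType
}

func main() {
	mode, kvct := resolveFlashAttention(os.Getenv("OLLAMA_KV_CACHE_TYPE"))
	fmt.Println("flash attention:", mode, "kv cache type:", kvct)
}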