flash attn: add auto mode for llama engine (#13052)

* flash attn: add auto mode for llama engine

If the user does not explicitly set flash attention (OLLAMA_FLASH_ATTENTION) in the environment, use auto mode.

* review comments

* ensure kv cache quantized types have FA explicitly enabled

additional review comments
Daniel Hiltgen 2025-12-12 13:27:19 -08:00 committed by GitHub
parent 3af5d3b738
commit bd6c1d6b49
7 changed files with 101 additions and 25 deletions
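
In short, the llama engine now carries a three-state flash attention setting instead of a boolean: when OLLAMA_FLASH_ATTENTION is unset the runner is asked to decide (auto), and an explicit setting is forwarded as enabled or disabled. A minimal standalone sketch of that mapping (illustrative only; resolveMode is a hypothetical helper, not part of this change):

package main

import "fmt"

// FlashAttentionType mirrors the new ml.FlashAttentionType added in this commit;
// the values are aligned with llama.cpp's llama_flash_attn_type.
type FlashAttentionType int32

const (
	FlashAttentionAuto     FlashAttentionType = -1
	FlashAttentionDisabled FlashAttentionType = 0
	FlashAttentionEnabled  FlashAttentionType = 1
)

// resolveMode illustrates the llama-engine behavior: an unset
// OLLAMA_FLASH_ATTENTION resolves to auto, an explicit value is forwarded.
func resolveMode(userSet, enabled bool) FlashAttentionType {
	switch {
	case !userSet:
		return FlashAttentionAuto
	case enabled:
		return FlashAttentionEnabled
	default:
		return FlashAttentionDisabled
	}
}

func main() {
	fmt.Println(resolveMode(false, false)) // -1: let the runner pick
	fmt.Println(resolveMode(true, true))   // 1: forced on
	fmt.Println(resolveMode(true, false))  // 0: forced off
}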

View File

@@ -13,6 +13,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/util/bufioutil"
+	"github.com/ollama/ollama/ml"
 )

 type GGML struct {
@@ -550,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
 	}, nil
 }

-func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
 	context *= uint64(numParallel)

 	embedding := f.KV().EmbeddingLength()
@@ -791,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	}

 	partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
-	if useFlashAttention {
+	if useFlashAttention == ml.FlashAttentionEnabled {
 		// rough estimate of graph size with flash attention on
 		partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
 	}
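
To give a sense of scale for that estimate: with numParallel = 1 and an 8192-token context (so context>>10 = 8), the flash-attention branch estimates partialOffload = (4*1 + 8 + 110) MiB = 122 MiB.
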
@@ -809,6 +810,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
 	return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
 }

+// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
+func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
+	if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
+		return false
+	}
+	return true
+}
+
 // SupportsFlashAttention checks if the model supports flash attention
 func (f GGML) SupportsFlashAttention() bool {
 	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]

View File

@@ -118,7 +118,7 @@ type ContextParams struct {
 	c C.struct_llama_context_params
 }

-func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
+func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
 	params.n_batch = C.uint(batchSize * numSeqMax)
@@ -127,10 +127,13 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
 	params.n_threads = C.int(threads)
 	params.n_threads_batch = params.n_threads
 	params.embeddings = C.bool(true)
-	if flashAttention {
-		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
-	} else {
-		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
+	switch flashAttention {
+	case ml.FlashAttentionEnabled:
+		params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED)
+	case ml.FlashAttentionDisabled:
+		params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED)
+	case ml.FlashAttentionAuto:
+		params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO)
 	}
 	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
 	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

View File

@@ -188,6 +188,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 	if len(projectors) > 0 && llamaModel != nil {
 		loadRequest.ProjectorPath = projectors[0]
 	}

+	// Determine if the user has forced FA on or off
+	faUserSet := false
+	if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
+		faUserSet = true
+	}
 	fa := envconfig.FlashAttention(f.FlashAttention())
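
The faUserSet probe works because envconfig.FlashAttention(defaultVal) returns the parsed OLLAMA_FLASH_ATTENTION value when the variable is set and defaultVal otherwise, so asking with both defaults and getting the same answer means the user decided it. A minimal sketch of that pattern, assuming that behavior (boolWithDefault below is an illustrative stand-in, not the real envconfig code):

package main

import (
	"fmt"
	"os"
	"strconv"
)

// boolWithDefault stands in for envconfig.FlashAttention: return the parsed
// environment value when set, otherwise fall back to the caller's default.
func boolWithDefault(key string, def bool) bool {
	if v, err := strconv.ParseBool(os.Getenv(key)); err == nil {
		return v
	}
	return def
}

func main() {
	// If the answer does not depend on the default, the env var must be set.
	userSet := boolWithDefault("OLLAMA_FLASH_ATTENTION", true) == boolWithDefault("OLLAMA_FLASH_ATTENTION", false)
	fmt.Println("OLLAMA_FLASH_ATTENTION explicitly set:", userSet)
}
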
@@ -205,19 +210,51 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 	kvct := strings.ToLower(envconfig.KvCacheType())

-	if fa {
-		slog.Info("enabling flash attention")
-		loadRequest.FlashAttention = true
-
-		// Flash Attention also supports kv cache quantization
-		// Enable if the requested and kv cache type is supported by the model
-		if f.SupportsKVCacheType(kvct) {
-			loadRequest.KvCacheType = kvct
-		} else {
-			slog.Warn("kv cache type not supported by model", "type", kvct)
-		}
-	} else if kvct != "" && kvct != "f16" {
-		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
-	}
+	if textProcessor == nil {
+		flashAttention := ml.FlashAttentionAuto
+		if faUserSet {
+			if fa {
+				flashAttention = ml.FlashAttentionEnabled
+			} else {
+				flashAttention = ml.FlashAttentionDisabled
+			}
+		}
+
+		if kvct != "" {
+			if f.KVCacheTypeIsQuantized(kvct) {
+				if flashAttention != ml.FlashAttentionEnabled {
+					slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
+					loadRequest.KvCacheType = ""
+				} else if f.SupportsKVCacheType(kvct) {
+					loadRequest.KvCacheType = kvct
+				} else {
+					slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
+				}
+			} else {
+				if f.SupportsKVCacheType(kvct) {
+					loadRequest.KvCacheType = kvct
+				} else {
+					slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
+				}
+			}
+		}
+
+		loadRequest.FlashAttention = flashAttention
+	} else {
+		// For Ollama engine, use our SupportsFlashAttention logic
+		if fa {
+			slog.Info("enabling flash attention")
+			loadRequest.FlashAttention = ml.FlashAttentionEnabled
+
+			// Flash Attention also supports kv cache quantization
+			// Enable if the requested and kv cache type is supported by the model
+			if f.SupportsKVCacheType(kvct) {
+				loadRequest.KvCacheType = kvct
+			} else {
+				slog.Warn("kv cache type not supported by model", "type", kvct)
+			}
+		} else if kvct != "" && kvct != "f16" {
+			slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
+		}
+	}
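
The net effect for the llama engine is that a quantized OLLAMA_KV_CACHE_TYPE is only honored when flash attention is explicitly enabled; under auto or disabled it is dropped with a warning. A condensed sketch of that gating (illustrative; it omits the SupportsKVCacheType model check that the real code also applies):

package main

import "fmt"

type faMode int32

const (
	faAuto faMode = -1
	faOff  faMode = 0
	faOn   faMode = 1
)

// isQuantized mirrors GGML.KVCacheTypeIsQuantized: anything other than the
// float cache types counts as quantized.
func isQuantized(cacheType string) bool {
	switch cacheType {
	case "", "f16", "f32", "bf16":
		return false
	}
	return true
}

// effectiveCacheType sketches the llama-engine decision: quantized cache types
// require flash attention to be forced on, otherwise the request is dropped.
func effectiveCacheType(requested string, mode faMode) string {
	if isQuantized(requested) && mode != faOn {
		return "" // the real code also logs a warning and clears loadRequest.KvCacheType
	}
	return requested
}

func main() {
	fmt.Printf("%q\n", effectiveCacheType("q8_0", faAuto)) // "" - dropped
	fmt.Printf("%q\n", effectiveCacheType("q8_0", faOn))   // "q8_0"
	fmt.Printf("%q\n", effectiveCacheType("f16", faAuto))  // "f16"
}
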
gpuLibs := ml.LibraryPaths(gpus) gpuLibs := ml.LibraryPaths(gpus)
@@ -435,7 +472,7 @@ type LoadRequest struct {
 	LoraPath       []string
 	Parallel       int
 	BatchSize      int
-	FlashAttention bool
+	FlashAttention ml.FlashAttentionType
 	KvSize         int
 	KvCacheType    string
 	NumThreads     int

View File

@@ -74,7 +74,7 @@ type BackendParams struct {
 	GPULayers GPULayersList

 	// FlashAttention indicates that we should use a fused flash attention kernel
-	FlashAttention bool
+	FlashAttention FlashAttentionType
 }

 var backends = make(map[string]func(string, BackendParams) (Backend, error))

View File

@@ -109,7 +109,7 @@ type Backend struct {
 	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
 	btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory

-	flashAttention bool
+	flashAttention ml.FlashAttentionType

 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
 	maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 }

 func (b *Backend) CacheConfig() ml.CacheConfig {
-	if b.flashAttention {
+	if b.flashAttention == ml.FlashAttentionEnabled {
 		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
 	} else {
 		return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)

-	if t.b.flashAttention {
+	if t.b.flashAttention == ml.FlashAttentionEnabled {
 		value = value.Permute(ctx, 0, 2, 1, 3)
 		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)

View File

@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
 	return true
 }

+type FlashAttentionType int32
+
+const (
+	// Aligned with llama_flash_attn_type
+	FlashAttentionAuto     FlashAttentionType = -1
+	FlashAttentionDisabled FlashAttentionType = 0
+	FlashAttentionEnabled  FlashAttentionType = 1
+)
+
+func (f FlashAttentionType) LogValue() slog.Value {
+	return slog.AnyValue(f.String())
+}
+
+func (f FlashAttentionType) String() string {
+	switch f {
+	case FlashAttentionAuto:
+		return "Auto"
+	case FlashAttentionDisabled:
+		return "Disabled"
+	case FlashAttentionEnabled:
+		return "Enabled"
+	default:
+		return "unknown"
+	}
+}
+
 // Given the list of GPUs this instantiation is targeted for,
 // figure out the visible devices environment variables
 // Set mustFilter true to enable filtering of CUDA devices
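
Since FlashAttentionType implements slog.LogValuer (and fmt.Stringer), log calls can take the value directly and render the readable name; for example (illustrative usage, assuming a default slog text handler):

// Prints something like: level=INFO msg="flash attention" mode=Auto
slog.Info("flash attention", "mode", ml.FlashAttentionAuto)
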

View File

@@ -26,6 +26,7 @@ import (
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/runner/common"
 )

@@ -832,7 +833,7 @@ func (s *Server) loadModel(
 	ppath string,
 	kvSize int,
 	kvCacheType string,
-	flashAttention bool,
+	flashAttention ml.FlashAttentionType,
 	threads int,
 	multiUserCache bool,
 ) {