Summary
Dao (2023) extends FlashAttention to address a different bottleneck: while FlashAttention-1 avoided the O(N²) HBM reads/writes of standard attention by never materializing the attention matrix, it reached only 25–40% of the theoretical maximum FLOPs/s on A100 GPUs. The inefficiency stems from suboptimal work partitioning across GPU thread blocks and warps, causing low occupancy and unnecessary shared-memory traffic. FlashAttention-2 makes three targeted improvements: (1) it reduces non-matmul FLOPs by restructuring the online softmax computation; (2) it parallelizes attention across the sequence-length dimension even for a single head, increasing GPU occupancy; (3) it partitions work within each thread block across warps to minimize shared-memory reads/writes.
These changes yield approximately a 2× speedup over FlashAttention-1, reaching 50–73% of the theoretical maximum A100 FLOPs/s — close to the efficiency of highly optimized matrix multiplication (GEMM) kernels. End-to-end GPT-style model training achieves up to 225 TFLOPs/s per A100 (72% model FLOPs utilization). FlashAttention-2 became the standard attention implementation in production LLM training frameworks (Megatron-LM, MosaicML, HuggingFace) and is the baseline against which subsequent attention optimizations (FlashAttention-3, ring attention, etc.) are measured.
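As a sanity check on the 72% MFU figure, assuming the commonly cited 312 TFLOPs/s dense BF16/FP16 tensor-core peak for the A100:

```python
# Model FLOPs utilization (MFU) = achieved model FLOPs/s / hardware peak FLOPs/s.
# Assumes the A100's 312 TFLOPs/s dense BF16/FP16 tensor-core peak.
peak_a100_bf16 = 312.0   # TFLOPs/s
achieved = 225.0         # TFLOPs/s, reported end-to-end GPT-style training throughput
mfu = achieved / peak_a100_bf16
print(f"{mfu:.0%}")      # → 72%
```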
Key Claims
- FlashAttention-2 reaches 50–73% of theoretical maximum FLOPs/s on A100, vs 25–40% for FlashAttention-1.
- ~2× wall-clock speedup over FlashAttention-1 on standard attention benchmarks.
- End-to-end GPT-style training throughput: up to 225 TFLOPs/s per A100 GPU (72% MFU).
- Supports sequence lengths up to 256K tokens (with sufficient HBM) with linear memory scaling.
- Backward pass is also re-implemented with improved parallelism, matching forward pass efficiency gains.
Methods
Three algorithmic changes over FlashAttention-1:
- Non-matmul FLOP reduction: the online-softmax rescaling in the forward pass is restructured so that the output accumulator is kept unscaled and divided by the softmax denominator once at the end of the inner loop, rather than being rescaled at every iteration. Non-matmul operations achieve far lower throughput than matmuls on tensor-core hardware, so trimming them matters.
- Sequence parallelism: the loop order is swapped so the outer loop runs over query blocks, which are mutually independent and can be assigned to separate thread blocks. Attention for a single head is thus parallelized across sequence length, whereas FlashAttention-1 parallelized only over batch and heads, leaving the GPU underoccupied at small batch sizes or few heads.
- Warp-level work partitioning: within each thread block, query rows (rather than key/value columns, as in FlashAttention-1's "split-K" scheme) are split across warps, so each warp computes its slice of the output independently, avoiding the cross-warp shared-memory writes and synchronization needed to combine partial results.
The algorithm remains exact (no approximation): outputs match standard attention up to floating-point rounding, though not bit-for-bit, since the summation order differs.
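The first two changes can be sketched in NumPy. This is an educational model, not the CUDA kernel: block sizes and names are illustrative, and the point is the FA2-style loop order (outer loop over independent query blocks) plus the single deferred rescale at the end of the inner loop, checked against naive attention.

```python
import numpy as np

def reference_attention(Q, K, V):
    """Standard attention softmax(QK^T / sqrt(d)) V, materializing the full score matrix."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

def fa2_style_attention(Q, K, V, block_q=16, block_k=16):
    """Tiled attention with FA2-style structure: the OUTER loop runs over query
    blocks (mutually independent, hence parallelizable across thread blocks on
    a GPU); the inner loop streams key/value blocks with an online softmax.
    The running output is kept UNSCALED and divided by the softmax denominator
    only once, after the inner loop, mirroring FA2's non-matmul FLOP reduction."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q)
    for qs in range(0, N, block_q):            # each iteration is independent
        Qi = Q[qs:qs + block_q]
        m = np.full(Qi.shape[0], -np.inf)      # running row max
        l = np.zeros(Qi.shape[0])              # running softmax denominator
        acc = np.zeros((Qi.shape[0], d))       # unscaled running output
        for ks in range(0, N, block_k):        # stream K/V blocks through "SRAM"
            S = Qi @ K[ks:ks + block_k].T * scale
            m_new = np.maximum(m, S.max(axis=-1))
            alpha = np.exp(m - m_new)          # correction factor for old partials
            P = np.exp(S - m_new[:, None])
            l = alpha * l + P.sum(axis=-1)
            acc = alpha[:, None] * acc + P @ V[ks:ks + block_k]
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]  # single final rescale per query block
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 32)) for _ in range(3))
assert np.allclose(fa2_style_attention(Q, K, V), reference_attention(Q, K, V))
```

In the real kernel each iteration of the outer loop is a separate thread block, and within it the query rows are further split across warps; the inner loop is where tiles of K and V are staged through shared memory.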
Failure modes
- FlashAttention-2 is optimized for A100/H100 GPU architecture; performance characteristics differ on other hardware (AMD, older NVIDIA, TPUs).
- Causal masking (autoregressive) attention is well-optimized but cross-attention (encoder-decoder) and non-causal attention patterns have lower utilization due to irregular tiling.
- Implementation complexity is high — writing correct FlashAttention kernels for new hardware or attention variants (e.g., sliding window, ALiBi) requires significant engineering effort.
- At very short sequence lengths (< 128 tokens), overhead of kernel launch and tiling setup reduces relative benefit vs. standard attention.
Connections
- flash-attention-fast-and-memory-efficient-exact-attention — FlashAttention-1, which this directly improves upon
- gqa-grouped-query-attention — GQA and FlashAttention-2 are jointly deployed; GQA reduces KV cache, FA2 speeds up the attention kernel
- mixtral-of-experts — Mixtral training uses FlashAttention-2
- llama-open-efficient-foundation-language-models — LLaMA-2 and subsequent models use FlashAttention-2 for training
- pagedattention-vllm — PagedAttention for efficient KV cache management complements FlashAttention-2 for inference
- flash-attention — the technique family
- attention — the operation being optimized
- inference-efficiency — broader category
- tri-dao — sole author
- stanford-hazy-research — affiliated lab
Citation
Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024. https://arxiv.org/abs/2307.08691