Concepts: flash-attention | attention | inference-efficiency | memory-efficiency
Builds on: flash-attention-fast-and-memory-efficient-exact-attention | attention-is-all-you-need
FlashAttention-1 fixed the memory wall: it stopped writing the N x N attention matrix to HBM. That gave a 2-4x speedup, but on A100 it still reached only 25-40% of the GPU’s peak FLOPs/s. The bottleneck had moved from memory traffic to how the work was scheduled on the GPU. Not enough thread blocks were running, the warps inside each block were stepping on each other, and the rescaling step was burning cycles on operations that were not matrix multiplies. FlashAttention-2 is one author (Tri Dao) rewriting the kernel to fix all three issues at once. The result is roughly 2x faster than FA1 and reaches 50-73% of peak A100 FLOPs/s.
The core idea
The analogy: FlashAttention-1 is a chef who stopped running back to the warehouse for ingredients (HBM). Now the bottleneck has shifted: only one burner of the eight-burner stove is on, the chef is ladling sauce by hand instead of using the food processor that’s right there, and the line cooks (warps) keep blocking each other in the cramped kitchen.
FA2 lights up all eight burners (more thread blocks running in parallel), assigns each line cook a clear lane (warps work on disjoint slices of work), and replaces hand-ladling with the food processor (re-orders the math so non-matmul operations happen rarely).
“We observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes.”
The three fixes:
- Reduce non-matmul FLOPs. Modern GPUs run matmuls on Tensor Cores far faster than general-purpose floating point. On A100 the figures are 312 vs 19.5 TFLOPs/s, about 16x. Even a small fraction of non-matmul work therefore caps the achievable throughput. FA2 restructures the online softmax so that the division by the softmax normalizer happens only once, after the last K/V tile, instead of at every tile.
- Parallelize over sequence length. FA1 parallelized over (batch, head). Long sequences usually come with small batches or few heads, and in that regime this leaves most of the GPU’s streaming multiprocessors (SMs) idle. FA2 also parallelizes over sequence-length tiles, so even a single attention head with one batch element can saturate the GPU when the sequence is long.
- Better warp partitioning inside each block. Within a thread block, FA1 split K/V across warps, so the warps had to exchange partial results through shared memory. FA2 instead gives each warp its own slice of query rows, and communication between warps drops to near zero.
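The following back-of-the-envelope Python model shows why a small fraction of non-matmul work matters. It assumes the A100 peak figures quoted above (312 TFLOPs/s fp16 Tensor Core, 19.5 TFLOPs/s fp32). It also makes the pessimistic assumption that the two kinds of work never overlap. `effective_tflops` is a hypothetical helper, not part of any library.

```python
# Back-of-the-envelope A100 model: matmul FLOPs run on Tensor Cores,
# everything else on the general-purpose FP32 units. Assumes the two kinds
# of work do not overlap, which is pessimistic but shows the trend.
MATMUL_TFLOPS = 312.0      # A100 dense fp16/bf16 Tensor Core peak
NON_MATMUL_TFLOPS = 19.5   # A100 fp32 (non-Tensor-Core) peak


def effective_tflops(non_matmul_fraction: float) -> float:
    """Achieved TFLOPs/s when this fraction of all FLOPs is non-matmul work."""
    if not 0.0 <= non_matmul_fraction <= 1.0:
        raise ValueError("non_matmul_fraction must be in [0, 1]")
    # Picoseconds per FLOP (1 / TFLOPs/s), averaged over the two kinds of work.
    ps_per_flop = ((1.0 - non_matmul_fraction) / MATMUL_TFLOPS
                   + non_matmul_fraction / NON_MATMUL_TFLOPS)
    return 1.0 / ps_per_flop


if __name__ == "__main__":
    print(f"Tensor Core advantage: {MATMUL_TFLOPS / NON_MATMUL_TFLOPS:.0f}x")
    for frac in (0.0, 0.01, 0.05, 0.10):
        tflops = effective_tflops(frac)
        print(f"{frac:4.0%} non-matmul -> {tflops:6.1f} TFLOPs/s "
              f"({tflops / MATMUL_TFLOPS:.0%} of peak)")
```

Under these assumptions, moving just 5% of the FLOPs off the Tensor Cores drops throughput to about 178 TFLOPs/s, roughly 57% of peak. Real kernels overlap some of this work, so read the numbers as a trend rather than a prediction.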
What’s clever — find the instinct
The non-obvious move is the diagnosis itself. FA1 was already considered close to optimal, and the conventional wisdom in 2023 was “the matmul is dominant; non-matmul is a rounding error.” Tri Dao’s contribution is showing that on Tensor Core hardware, non-matmul work takes a disproportionate share of runtime once HBM is no longer the bottleneck. It is a small fraction of the FLOPs, but each of those FLOPs costs about 16x more. Every cycle a warp spends on softmax rescaling is a cycle the Tensor Cores spend idle.
The numerical reformulation: in FA1, after each new tile (K_j, V_j) is processed, the output accumulator O_i is rescaled by l_old * exp(m_old - m_new) / l_new, and the new tile's contribution is also divided by l_new. That means extra scalar multiplications and a division on every output element, on every tile. All of it is non-matmul work, and slow.
FA2 keeps the rescaling but defers the per-element division until the very end. Internally it tracks a running un-normalized output and a running normalizer. Only at the end of the row does it divide. Same result, far fewer non-matmul operations.
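Here are the two update rules written out for one query tile Q_i and one new K/V tile (K_j, V_j), in the notation above. S_ij = Q_i K_j^T, the softmax scale is omitted, and exp, rowmax and rowsum act row-wise. This is a restatement with FA1's per-tile max folded into the running max m_new, not the papers' exact typesetting.

```latex
\begin{aligned}
m_{\text{new}} &= \max\big(m_{\text{old}},\ \operatorname{rowmax}(S_{ij})\big), \qquad
\tilde P_{ij} = \exp\big(S_{ij} - m_{\text{new}}\big), \\
l_{\text{new}} &= e^{m_{\text{old}} - m_{\text{new}}}\, l_{\text{old}} + \operatorname{rowsum}(\tilde P_{ij}), \\[4pt]
\text{FA1:}\quad O_i &\leftarrow \operatorname{diag}(l_{\text{new}})^{-1}
  \Big(\operatorname{diag}\big(l_{\text{old}}\, e^{m_{\text{old}} - m_{\text{new}}}\big)\, O_i + \tilde P_{ij} V_j\Big), \\
\text{FA2:}\quad \tilde O_i &\leftarrow \operatorname{diag}\big(e^{m_{\text{old}} - m_{\text{new}}}\big)\, \tilde O_i + \tilde P_{ij} V_j,
\qquad O_i = \operatorname{diag}(l_{\text{final}})^{-1}\, \tilde O_i \ \text{after the last tile.}
\end{aligned}
```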
“We tweak the algorithm to reduce the number of non-matmul FLOPs.”
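The same idea can be sketched in pure Python for a single query row. This is a minimal sketch under several assumptions: the function names are hypothetical, plain loops stand in for the GPU's tiled matmuls, and the softmax scale is assumed to be folded into `scores`. Both tiled versions should match a plain softmax to floating-point precision. The difference is that `row_fa1_style` divides every output element on every tile, while `row_fa2_style` divides once.

```python
import math

# Hypothetical helpers for ONE query row: `scores` are that row's q.k values,
# `values` are the matching rows of V. The 1/sqrt(d) scale is assumed to be
# folded into `scores` already.


def row_fa1_style(scores, values, tile=2):
    """Keep the output normalized after every K/V tile (divide on every tile)."""
    d = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(scores), tile):
        s_t, v_t = scores[start:start + tile], values[start:start + tile]
        m_new = max(m, max(s_t))
        p = [math.exp(s - m_new) for s in s_t]
        alpha = math.exp(m - m_new)      # exp(m_old - m_new); 0.0 on the first tile
        l_new = alpha * l + sum(p)
        pv = [sum(pk * vk[c] for pk, vk in zip(p, v_t)) for c in range(d)]
        # l_old * exp(m_old - m_new) * O + P V, divided by l_new: per element, per tile.
        o = [(l * alpha * o[c] + pv[c]) / l_new for c in range(d)]
        m, l = m_new, l_new
    return o


def row_fa2_style(scores, values, tile=2):
    """Carry an un-normalized accumulator; divide by the normalizer once at the end."""
    d = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(scores), tile):
        s_t, v_t = scores[start:start + tile], values[start:start + tile]
        m_new = max(m, max(s_t))
        p = [math.exp(s - m_new) for s in s_t]
        alpha = math.exp(m - m_new)
        l = alpha * l + sum(p)
        acc = [alpha * acc[c] + sum(pk * vk[c] for pk, vk in zip(p, v_t))
               for c in range(d)]
        m = m_new
    return [a / l for a in acc]          # the only division


def row_reference(scores, values):
    """Plain softmax(scores) @ values, for comparison."""
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    z = sum(w)
    return [sum(wk * vk[c] for wk, vk in zip(w, values)) / z
            for c in range(len(values[0]))]
```

Calling all three on the same `scores` and `values`, with any tile size, should give the same output row. Only the amount of per-element division differs.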
The second clever move is swapping the inner and outer loops. In FA1, the outer loop is over K/V tiles (the key/value columns) and the inner loop is over Q tiles (the query rows). Each Q tile's output block O_i and its running statistics are therefore read from and written back to HBM once per K/V tile. On top of that, each (batch, head) is handled by a single thread block, which limits parallelism.
FA2 swaps them: the outer loop runs over Q tiles and the inner loop over K/V tiles. Each output block is now written to HBM only once, after its last K/V tile. Because different Q tiles never interact, the outer loop can be split across thread blocks. Each thread block owns a slice of rows of the output and processes the entire K/V matrix for those rows. Different thread blocks work on different row slices, in parallel, without communicating. The number of thread blocks scales with sequence length, not just (batch, head), so SM occupancy stays high even for B=1, H=1, N=8K.
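The loop order can be sketched in Python under the same simplifications: hypothetical names and no softmax scale. `row_block_forward` stands in for one FA2 thread block, which owns a tile of Q rows and sweeps all K/V tiles. `fa2_style_forward` hands the Q tiles to a `ThreadPoolExecutor` only to make their independence explicit. Python threads bring no real speedup here; in the real kernel, the GPU's block scheduler does this job.

```python
import math
from concurrent.futures import ThreadPoolExecutor

# Hypothetical sketch: each call to row_block_forward plays the role of one
# FA2 thread block. Row blocks share no state, so any scheduler could run
# them in parallel. The 1/sqrt(d) softmax scale is omitted for brevity.


def scores_block(q_rows, k_rows):
    """q_rows @ k_rows^T using plain lists."""
    return [[sum(x * y for x, y in zip(qr, kr)) for kr in k_rows] for qr in q_rows]


def row_block_forward(q_block, k, v, kv_tile):
    """Own these query rows; sweep every K/V tile (the inner loop)."""
    rows, d = len(q_block), len(v[0])
    m = [-math.inf] * rows
    l = [0.0] * rows
    acc = [[0.0] * d for _ in range(rows)]
    for start in range(0, len(k), kv_tile):
        k_t, v_t = k[start:start + kv_tile], v[start:start + kv_tile]
        s = scores_block(q_block, k_t)
        for r in range(rows):
            m_new = max(m[r], max(s[r]))
            p = [math.exp(x - m_new) for x in s[r]]
            alpha = math.exp(m[r] - m_new)
            l[r] = alpha * l[r] + sum(p)
            acc[r] = [alpha * acc[r][c] + sum(pk * vk[c] for pk, vk in zip(p, v_t))
                      for c in range(d)]
            m[r] = m_new
    # Output rows are produced once, after the last K/V tile.
    return [[a / l[r] for a in acc[r]] for r in range(rows)]


def fa2_style_forward(q, k, v, q_tile=2, kv_tile=2, workers=4):
    """Outer loop over Q tiles, farmed out as independent jobs."""
    q_blocks = [q[i:i + q_tile] for i in range(0, len(q), q_tile)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        outs = list(pool.map(lambda qb: row_block_forward(qb, k, v, kv_tile), q_blocks))
    return [row for out in outs for row in out]


def naive_attention(q, k, v):
    """softmax(q k^T) v, materializing the full score matrix."""
    out = []
    for s in scores_block(q, k):
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(wk * vk[c] for wk, vk in zip(w, v)) / z for c in range(len(v[0]))])
    return out
```

`pool.map` returns results in input order, so the concatenated rows line up with `q`. That is why no coordination between the jobs is needed.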
“We parallelize the attention computation, even for a single head, across different thread blocks to increase occupancy.”
The third clever move is the warp partitioning. Within a thread block, FA1 partitioned the K/V dimension across warps. With 4 warps and a 128-column K/V tile, as in the walkthrough below, warp 0 handled columns 0..31, warp 1 handled columns 32..63, and so on. This required all warps to write partial results to shared memory and a reduction across warps to combine them. FA2 partitions the Q (output) dimension across warps instead. With a 64-row Q tile, warp 0 owns output rows 0..15 and warp 1 owns rows 16..31. Each warp does a full pass over K/V independently. No reduction is needed, and no partial results cross between warps through shared memory.
“Within each thread block, distribute the work between warps to reduce communication through shared memory.”
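The warp-level difference can be sketched with plain loops standing in for warps. This is a simplified model over a single K/V tile, and the function names are hypothetical. The `shared` list plays the role of shared memory, and `traffic` counts the values warps must hand to each other through it.

```python
import math

# Hypothetical sketch of one thread block's work on one tile, with plain loops
# standing in for warps. Softmax scale omitted.


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def partial_row(q_row, k_part, v_part):
    """(row max, row sum, un-normalized output) of one query vs a slice of K/V."""
    s = [dot(q_row, kr) for kr in k_part]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    o = [sum(pk * vr[c] for pk, vr in zip(p, v_part)) for c in range(len(v_part[0]))]
    return m, sum(p), o


def split_k(q, k, v, n_warps):
    """FA1-style: warps split the keys/values, then merge partials via shared memory."""
    step = math.ceil(len(k) / n_warps)
    shared = []  # stands in for shared memory
    for w in range(n_warps):
        k_w, v_w = k[w * step:(w + 1) * step], v[w * step:(w + 1) * step]
        if k_w:
            shared.append([partial_row(qr, k_w, v_w) for qr in q])
    # Each warp publishes max, sum, and a full partial output row per query.
    traffic = sum(2 + len(o) for warp in shared for _, _, o in warp)
    out = []
    for r in range(len(q)):  # cross-warp reduction, one query row at a time
        parts = [warp[r] for warp in shared]
        m = max(pm for pm, _, _ in parts)
        l = sum(math.exp(pm - m) * pl for pm, pl, _ in parts)
        out.append([sum(math.exp(pm - m) * po[c] for pm, _, po in parts) / l
                    for c in range(len(v[0]))])
    return out, traffic


def split_q(q, k, v, n_warps):
    """FA2-style: each warp owns whole query rows and reads all of K/V."""
    step = math.ceil(len(q) / n_warps)
    out = []
    for w in range(n_warps):
        for qr in q[w * step:(w + 1) * step]:
            _, l, o = partial_row(qr, k, v)
            out.append([x / l for x in o])
    return out, 0  # nothing exchanged between warps
```

Both partitions produce the same output. Only `split_k` has to exchange each warp's per-row max, sum, and partial output row, and then reduce them. That exchange is the traffic FA2's row-wise split removes.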
Walkthrough: where the speed comes from
Setup: A100 GPU. Batch=1, heads=8, seq_len=8192, head_dim=128.
FA1 layout:
Parallelism granularity: (batch x heads) = 1 x 8 = 8 thread blocks
A100 has 108 SMs.
Result: 8 / 108 = 7% SM utilization.
Most of the GPU is idle.
FA2 layout:
Parallelism granularity: (batch x heads x seq_tiles) = 1 x 8 x (8192/64) = 1024 thread blocks
Result: full SM utilization, more thread blocks than SMs (good — pipelined).
Non-matmul FLOPs (per tile, per row):
FA1: 1 max compare, 2 exps, 2 multiplies, 1 add, 1 division — repeated each tile.
Across all tiles: O(N) non-matmul ops per row.
FA2: 1 max, 2 exps, 1 multiply, 1 add — per tile. Final division done ONCE at end.
Reduction in non-matmul ops: ~30%.
Within thread block, warp partition (4 warps per block):
FA1: each warp processes K_block[w*32:(w+1)*32]; partial outputs combined via shared memory.
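Cross-warp shared memory traffic per tile: each warp writes its partial output for all 64 query rows (64 x 128 values at head_dim 128), plus a per-row max and sum. These are then reduced across warps.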
FA2: each warp processes Q_block[w*16:(w+1)*16] independently.
Cross-warp shared memory traffic: 0.
Attention-kernel throughput:
FA1: ~110 TFLOPs/s on A100 (35% of the theoretical 312 TFLOPs/s for fp16 Tensor Cores).
FA2: ~225 TFLOPs/s on A100 (72% of theoretical).
Roughly 2x faster on the attention kernel itself.
The 72% MFU number is significant: GEMM on A100 at this size hits about 80-85%. FA2 is within striking distance of optimized matrix multiplication, which was unimaginable for attention three years earlier.
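The walkthrough's occupancy and utilization arithmetic fits in a small Python sketch. The helper names are hypothetical. The constants are the ones used above: 108 SMs, 312 TFLOPs/s fp16 peak, and 64-row Q tiles. `busy_sm_fraction` is only an upper bound, because it ignores how many blocks fit on each SM.

```python
import math

# Hypothetical helpers reproducing the walkthrough arithmetic.
A100_SMS = 108
A100_FP16_TENSOR_PEAK_TFLOPS = 312.0


def thread_blocks(batch, heads, seq_len=None, q_tile=64):
    """FA1 launches batch*heads blocks; FA2 also splits each sequence into Q tiles."""
    blocks = batch * heads
    if seq_len is not None:
        blocks *= math.ceil(seq_len / q_tile)
    return blocks


def busy_sm_fraction(blocks, sms=A100_SMS):
    """Upper bound on the fraction of SMs that receive any work."""
    return min(blocks, sms) / sms


def utilization(tflops, peak=A100_FP16_TENSOR_PEAK_TFLOPS):
    """Achieved throughput as a fraction of the Tensor Core peak."""
    return tflops / peak


if __name__ == "__main__":
    fa1 = thread_blocks(batch=1, heads=8)
    fa2 = thread_blocks(batch=1, heads=8, seq_len=8192)
    print(fa1, f"{busy_sm_fraction(fa1):.0%}")
    print(fa2, f"{busy_sm_fraction(fa2):.0%}")
    print(f"{utilization(110):.0%} vs {utilization(225):.0%}")
```

Running it reproduces the numbers above:
- 8 versus 1024 thread blocks
- 7% versus 100% of SMs with work
- 35% versus 72% of peak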
Does it work? What breaks?
Headline numbers (A100 80GB, fp16, head dim 128):
| Sequence Length | Standard | FlashAttention-1 | FlashAttention-2 |
|---|---|---|---|
| 512 | 89 TFLOPs/s | 110 TFLOPs/s | 187 TFLOPs/s |
| 1024 | 87 | 134 | 211 |
| 2048 | OOM | 159 | 225 |
| 8192 | OOM | 162 | 219 |
| 16384 | OOM | 159 | 220 |
End-to-end GPT-style training: up to 2.8x faster than a baseline without FlashAttention and about 1.3x faster than FA1, reaching up to 225 TFLOPs/s per A100. Memory still scales linearly in sequence length (the FA1 inheritance).
“FlashAttention-2 reaches 50-73% of the theoretical maximum FLOPs/s on A100 and gets close to the efficiency of GEMM operations.”
What breaks:
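- Hardware specificity. FA2 is hand-tuned for NVIDIA Ampere (A100). It also runs on Hopper (H100) but does not use Hopper-specific features such as TMA and WGMMA, which FlashAttention-3 was written to exploit. On older hardware (V100), AMD GPUs, or TPUs, the gains are smaller or absent.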
- Cross-attention. Optimized for self-attention with causal masking. Encoder-decoder cross-attention has irregular tiling that FA2 handles less well.
- Custom attention patterns. Sliding window, ALiBi, custom masks all require new kernel variants. The community has built these (e.g., FlashAttention with sliding window for Mistral) but they are not free.
- Short sequences. At N < 128, kernel launch overhead dominates; standard attention is competitive.
- Backward pass. The backward kernel is even more complex and was not initially as optimized; subsequent FA2 versions improved it.
- Numerical edge cases. The deferred-division reformulation can underflow on extremely peaked attention distributions; rare but needs care.
So what?
For a practitioner training or serving transformer models:
- FA2 is the default attention kernel today. PyTorch’s F.scaled_dot_product_attention dispatches to FA2 on supported hardware. xFormers, Megatron, vLLM, and HuggingFace Transformers all use it under the hood. You should rarely need to write your own attention kernel.
- The 2x speedup is for the attention kernel, not the whole model. MLPs and the rest of the network are untouched, so end-to-end GPT-style training is about 1.3x faster than with FA1. The end-to-end gain is larger at long context, where attention is a bigger share of the compute.
- Long-context training got affordable. 32K and 128K context windows became practical because FA2 keeps memory linear and throughput high. Prior to FA2, training on 32K sequences was a heroic engineering project.
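- The lessons generalize. FlashAttention-3 (Tri Dao with collaborators, targeting H100) follows the same playbook. It keeps the Tensor Cores busy by minimizing and hiding non-matmul work, maximizes SM occupancy, and schedules warps so they overlap rather than wait on each other. Ring attention carries the same blockwise online-softmax idea across devices for distributed long context.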
- Kernel optimization is the new frontier. As Tensor Cores get faster (H100, B100), the relative cost of non-matmul work grows. Attention variants and any new layer types are worth implementing as fused kernels, not as composed PyTorch ops.
For an L5 systems interview: “what is the difference between FA1 and FA2” is asking whether you understand that the bottleneck moved. FA1 fixed memory bandwidth; FA2 fixed compute occupancy. Both are correct optimizations of the same algorithm; they target different tiers of the GPU’s performance hierarchy.
Connections
- flash-attention-fast-and-memory-efficient-exact-attention — FA1, the direct predecessor that this builds on
- attention-is-all-you-need — the operation being optimized
- gqa-grouped-query-attention — GQA is the standard companion: FA2 speeds up the kernel, GQA reduces the KV cache
- mixtral-of-experts — Mixtral training uses FA2
- llama-2-open-foundation-fine-tuned-chat-models — LLaMA-2 and downstream models use FA2 for training
- pagedattention-vllm — PagedAttention (KV-cache management) complements FA2 for inference
- flash-attention — the technique family
- attention — the operation
- inference-efficiency — broader category
- tri-dao — sole author
- stanford-hazy-research — affiliated lab
Citation
Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024. https://arxiv.org/abs/2307.08691