Every few years someone publishes a paper that improves a standard algorithm by 3x without changing any of the math. Usually the trick is the same: everyone was optimizing for the wrong metric. For five years, researchers tried to speed up attention by making it approximate — doing less work. FlashAttention said: the work is fine. The problem is the commute.
The analogy
Picture a factory floor with two storage rooms: a tiny cabinet right next to the assembly line (fast to reach, holds 20 items), and a giant warehouse across the parking lot (holds thousands of items, but takes 12 minutes to walk there).
Standard attention looks like this: every step of production requires the manager to walk to the warehouse, bring back a massive stack of papers, do some math, walk the results back over, then set out again. The assembly workers are 90% idle, waiting for the manager’s warehouse runs.
FlashAttention says: stop making all those warehouse trips. Bring a small batch of papers to the assembly line cabinet. Do all the math you can with that batch. Then get the next batch. The workers stay busy. The warehouse trips drop from hundreds to tens.
The “assembly line cabinet” is GPU SRAM (on-chip memory, ~19 TB/s bandwidth). The “warehouse” is GPU HBM — the big memory you think of when you say “my GPU has 80 GB.” HBM is 12× slower. Standard attention runs back and forth between them constantly. FlashAttention mostly stays in SRAM.
The mechanism
First, understand the GPU memory hierarchy:
GPU Memory Hierarchy
─────────────────────────────────────────────────────────────
SRAM (on-chip, per-SM): ~20 MB total, ~19 TB/s bandwidth
↑ tiny but FAST
│ 12× faster bandwidth
HBM (off-chip DRAM): ~40–80 GB, ~1.5–2 TB/s bandwidth
↓ big but SLOW
─────────────────────────────────────────────────────────────
Standard attention does this:
Standard Attention — 6 HBM round-trips for a single layer:
─────────────────────────────────────────────────────────────
Step 1: Load Q, K from HBM → compute S = QKᵀ / √d [N×N matrix!]
Step 2: Write S to HBM
Step 3: Load S from HBM → compute P = softmax(S)
Step 4: Write P to HBM
Step 5: Load P, V from HBM → compute O = PV
Step 6: Write O to HBM
─────────────────────────────────────────────────────────────
HBM reads/writes: O(N²) — a full N×N matrix crosses the slow bus twice
For N=8,192 tokens (a modest long-context prompt), that’s 67 million floats crossing the slow bus — twice, for every transformer layer. With 32 layers and fp16, you’re moving 8+ GB of intermediate data per forward pass just for attention.
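A quick back-of-envelope check of those figures, counting only the HBM writes of S and P in fp16 (the reads roughly double the traffic):

```python
# Back-of-envelope HBM traffic from the N x N intermediates in standard
# attention, using the figures quoted above: fp16, N = 8,192, 32 layers.
N, n_layers, bytes_per_elem = 8192, 32, 2

per_matrix = N * N * bytes_per_elem      # one N x N score matrix in fp16
per_layer_writes = 2 * per_matrix        # S and P each written to HBM once
total_gib = per_layer_writes * n_layers / 2**30

print(f"{per_matrix / 2**20:.0f} MiB per matrix")        # 128 MiB
print(f"{total_gib:.0f} GiB written per forward pass")   # 8 GiB
```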
FlashAttention does this instead:
FlashAttention — tile-by-tile, mostly in SRAM:
─────────────────────────────────────────────────────────────
Divide Q into row tiles (Q₁, Q₂, ...), K/V into column tiles.
For each (Q tile, K tile) pair that fits in SRAM:
Load tiny Q_block, K_block, V_block → SRAM (fast!)
Compute partial QKᵀ scores in SRAM
Apply online softmax correction in SRAM
Accumulate into running output O
Discard partial results — never write to HBM
Write final O back to HBM once.
─────────────────────────────────────────────────────────────
HBM reads/writes: O(N²d²/M), where M = SRAM size and d = head dimension
→ 5–10× fewer bytes moved for typical sequences
The N×N attention matrix is never materialized in HBM at all. It lives briefly in fast on-chip SRAM tile by tile, gets used, and is discarded. The only HBM write is the final output O — exactly what you needed anyway.
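The tiling loop above can be written out in a few lines of numpy. This is a didactic sketch, not the real kernel (the actual implementation fuses everything into a single CUDA kernel so the inner-loop tiles live in SRAM), and `flash_attention_sketch` is a name invented here:

```python
import numpy as np

def flash_attention_sketch(Q, K, V, tile=2):
    # Exact attention, computed tile by tile; the full N x N score matrix
    # is never built. In the real kernel, each inner-loop tile sits in SRAM.
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)
    for i in range(0, N, tile):                 # loop over Q row tiles
        q = Q[i:i + tile]
        m = np.full(q.shape[0], -np.inf)        # running row maxima
        l = np.zeros(q.shape[0])                # running exp-sums
        o = np.zeros_like(q)                    # unnormalized partial output
        for j in range(0, N, tile):             # loop over K/V column tiles
            s = (q @ K[j:j + tile].T) * scale   # small tile of scores
            m_new = np.maximum(m, s.max(axis=1))
            corr = np.exp(m - m_new)            # rescales the old statistics
            p = np.exp(s - m_new[:, None])      # shifted exponentials
            l = l * corr + p.sum(axis=1)
            o = o * corr[:, None] + p @ V[j:j + tile]
            m = m_new
        O[i:i + tile] = o / l[:, None]          # normalize once per row tile
    return O
```

Compared against the textbook three-matmul version, the outputs agree to floating-point precision, because the online softmax correction is exact, not approximate.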
The math that matters
Standard softmax of attention scores:
P = softmax(S) where S = QKᵀ / √d
Computing softmax over a row requires seeing the entire row — you need the maximum value to subtract for numerical stability, and the sum of all exponentials for normalization. With standard attention, the whole N×N matrix S is in HBM.
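In code, the standard "safe" softmax makes that dependency explicit: both statistics require a full pass over the row before any output can be produced.

```python
import numpy as np

def safe_softmax(row):
    m = row.max()           # pass 1: need the global max for stability
    e = np.exp(row - m)     # shift so the largest exponent is 0
    return e / e.sum()      # pass 2: need the full sum to normalize
```

Running it on the score row `[1.0, 0.0, 0.5, 0.5]` used in the walkthrough below reproduces the probabilities computed there.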
FlashAttention uses the online softmax algorithm (Milakov & Gimelshein, 2018) to compute this incrementally without ever needing the full row:
Keep two running statistics per row:
- m_i = running maximum seen so far
- ℓ_i = running sum of shifted exponentials so far
When you see a new tile of scores, you update these statistics and rescale the previous partial output to account for the new maximum. The final result is mathematically identical to computing full softmax — no approximation.
The IO complexity analysis from the paper shows:
FlashAttention requires O(N²) FLOPs but only O(N) extra memory and O(N²d²/M) HBM accesses, where d is the head dimension and M the SRAM size
Same number of arithmetic operations as standard attention. Dramatically fewer memory transfers. When the bottleneck is bandwidth (which it is), fewer transfers = faster wall-clock time.
Walkthrough with actual numbers
Let’s trace through a tiny 4-token sequence with head dimension d=4.
Standard attention: Compute full 4×4 score matrix first:
S = QKᵀ × 0.5 (since 1/√d = 1/2 for d = 4)
Row 0 = [1.0, 0.0, 0.5, 0.5] → write to HBM
softmax([1.0, 0.0, 0.5, 0.5]):
max = 1.0
exp([0.0, -1.0, -0.5, -0.5]) = [1.0, 0.368, 0.607, 0.607]
sum = 2.582
P[0] = [0.387, 0.142, 0.235, 0.235] → write to HBM again
FlashAttention with block size 2 — process K tiles one at a time:
Process Q[0] with K[:2] (first tile):
Partial scores: s = [1.0, 0.0]
Running max: m = 1.0
Running sum: ℓ = e^(1.0-1.0) + e^(0.0-1.0) = 1.0 + 0.368 = 1.368
Partial output: o = [V[0]×1.0 + V[1]×0.368] / 1.368 ← tentative
Process Q[0] with K[2:] (second tile):
New scores: s = [0.5, 0.5]
New max: max(1.0, 0.5, 0.5) = 1.0 — no change!
New sum: ℓ = 1.368 + e^(0.5-1.0) + e^(0.5-1.0) = 1.368 + 0.607 + 0.607 = 2.582
Update output: rescale previous × (1.368/2.582) + new terms × (1/2.582)
Final P[0] = [0.387, 0.142, 0.235, 0.235] ← identical to standard!
Everything happened in SRAM. The full 4×4 score matrix was never written to HBM. At N=8K tokens, this means not writing a 64M-element matrix twice — that’s 256 MB of HBM traffic eliminated for a single layer.
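The two-tile update above can be checked mechanically. `online_update` is a name made up for this sketch; with identity values for V, the accumulated output is exactly the softmax row P[0]:

```python
import numpy as np

def online_update(m, l, o, s_tile, v_tile):
    # One online-softmax step for a single query row. (m, l, o) are the
    # running max, running exp-sum, and running normalized partial output.
    m_new = max(m, s_tile.max())
    corr = np.exp(m - m_new)        # rescales the old statistics
    p = np.exp(s_tile - m_new)      # shifted exponentials of the new scores
    l_new = l * corr + p.sum()
    # Undo the old normalization (x l), rescale, add new terms, renormalize.
    o_new = (o * l * corr + p @ v_tile) / l_new
    return m_new, l_new, o_new

s = np.array([1.0, 0.0, 0.5, 0.5])   # the Q[0] score row from the walkthrough
V = np.eye(4)                        # identity values, so output == P[0]
m, l, o = -np.inf, 0.0, np.zeros(4)
m, l, o = online_update(m, l, o, s[:2], V[:2])   # first K tile
m, l, o = online_update(m, l, o, s[2:], V[2:])   # second K tile

full = np.exp(s - s.max()); full /= full.sum()
assert np.allclose(o, full)          # identical to one-shot softmax
```

Note that the second update leaves the running max unchanged, so the correction factor is 1 and only the denominator grows, exactly as in the trace above.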
What’s clever — find the instinct
The 2017–2022 ML research community knew attention was quadratic and spent years trying to make it less quadratic: sparse attention, linear attention, random-feature approximations. Longformer, Linformer, Performer, BigBird. Nearly all of them traded away quality, and most failed to achieve wall-clock speedups on real hardware despite their theoretical improvements.
The instinct Tri Dao had: what if the algorithm is fine and the implementation is the problem?
This required thinking like a systems engineer, not an ML researcher. ML researchers think in computation graphs. Systems engineers think in cache lines. The question was: “Where does time actually go when I run attention on a GPU?” The answer was: not in the arithmetic. In the memory bus.
The recomputation insight is also worth dwelling on. During the backward pass for gradients, standard attention reads the saved N×N attention matrix P back from HBM (it’s needed to compute gradients). FlashAttention never stored P in the first place; it recomputes P tile by tile during the backward pass from Q, K, and the softmax normalization statistics saved in the forward pass. This trades some extra arithmetic for a massive reduction in memory traffic. It works because arithmetic is cheap; HBM bandwidth is the bottleneck.
Real quotes from the paper
“We argue that a missing principle is making attention algorithms IO-aware — accounting for reads and writes between levels of GPU memory.”
Translation: Before this paper, there was no framework for even asking “how much memory bandwidth does attention use?” The FLOP count was the standard metric. IO-awareness is a new category of analysis that turns out to be the actually relevant one.
“FlashAttention requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes.”
Translation: It’s not just faster in practice — it’s proven optimal. The paper includes a formal lower bound showing you can’t do better given a fixed SRAM size. This is a theoretical result, not just engineering.
“We use tiling to prevent large N×N attention matrices from ever being instantiated in HBM.”
Translation: The key realization. You don’t need to hold the whole attention matrix at once. You can process it in small tiles, use the online softmax trick to accumulate correctly, and throw away each tile after use. The final answer is the same. The memory footprint during computation collapses from O(N²) to O(N).
“FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large… 3× speedup on GPT-2.”
Translation: These are real numbers on real workloads, not microbenchmarks. BERT-large on the MLPerf training benchmark: 15% faster. GPT-2: 3×. The arithmetic is unchanged; only the memory access pattern is.
Does it actually work?
| Model | Baseline | FlashAttention | Speedup |
|---|---|---|---|
| BERT-large (seq=512) | MLPerf 1.1 record | 15% faster end-to-end | 1.15× wall-clock |
| GPT-2 (seq=1K) | Standard PyTorch attention | 3× faster training | 3× |
| Long-range arena (seq=1K–4K) | Standard attention | 2.4× faster | 2.4× |
More striking: FlashAttention enabled the first Transformers to solve Path-X, a benchmark requiring 16,384-token context (61.4% accuracy). With standard attention this was infeasible on memory grounds: the N×N matrix at N=16K holds roughly 268 million floats, about half a gigabyte in fp16 per attention head, per layer. FlashAttention never materializes it, making 16K-context training possible on a single GPU.
What doesn’t work:
The original implementation requires hand-written CUDA kernels. It’s not easy to port to TPUs, AMD GPUs, or custom accelerators, and early adoption was slowed by an obvious barrier: who wants to maintain custom CUDA code? (FlashAttention-2 and -3 have broadened hardware support substantially.)
It’s still O(N²) FLOPs. For truly long contexts (1M tokens), the quadratic compute still becomes a wall — FlashAttention buys constant factors, not asymptotic improvement.
Backward pass adds complexity: recomputing P tile by tile requires keeping the output O and the per-row softmax statistics from the forward pass, and the backward kernel is considerably harder to implement correctly than the forward one.
So what?
If you’re training any Transformer today, you’re almost certainly using FlashAttention already — it’s been integrated into PyTorch’s torch.nn.functional.scaled_dot_product_attention, Hugging Face, and every major training framework.
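In recent PyTorch versions the fused kernel is simply the default code path of that function (the shapes below are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 8, 16)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)

# Dispatches to a fused FlashAttention-style kernel when the backend
# supports it (CUDA, suitable dtypes); otherwise falls back to plain math.
out = F.scaled_dot_product_attention(q, k, v)

# Numerically equivalent to the textbook three-step version:
ref = torch.softmax(q @ k.transpose(-2, -1) / 16 ** 0.5, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-5)
```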
The broader lesson: when models feel slow or memory-hungry, ask what the GPU is actually doing, not just what the algorithm says it should do. The gap between theoretical FLOP count and actual hardware utilization is where most practical speedups hide. PagedAttention found this same gap on the inference side. The pattern repeats: the obvious bottleneck (the math) isn’t the bottleneck (the memory traffic).
Compute is cheap. Moving data is expensive. If your algorithm is moving data for no good reason, someone will eventually write a paper showing you how to stop — and they’ll get a 3× speedup without changing a single gradient.
Connections
- flash-attention — the technique introduced here
- attention — the operation being optimized
- kv-cache — KV cache efficiency connects to this
- speculative-decoding — shares the memory-bandwidth bottleneck insight
Citation
Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. https://arxiv.org/abs/2205.14135