The Problem

Standard attention requires materializing the N×N attention matrix — for N=8192 tokens, that’s 67M entries, 268MB in float32. This matrix is written to GPU HBM (high bandwidth memory, the off-chip DRAM), then read back multiple times to compute softmax, then to multiply by V.

The problem isn’t FLOPs. An A100 can do 312 TFLOPS. The problem is memory bandwidth: HBM transfers run at ~2 TB/s, but writing and reading that 268MB attention matrix multiple times eats the bandwidth budget before you’ve done much compute.

Standard attention achieves only ~40% of theoretical peak GPU FLOPs because it stalls on memory transfers, not because it runs out of compute.
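A back-of-the-envelope check of these numbers (a sketch: the 312 TFLOPS and ~2 TB/s figures are the A100 spec values quoted above; the head dim d=64 and the four-pass traffic model are illustrative assumptions):

```python
# Rough roofline arithmetic for standard attention at N=8192 (fp32).
# d=64 is an assumed head dim; bandwidth/FLOPS are A100 spec values.
N, d = 8192, 64
bytes_fp32 = 4

attn_matrix_bytes = N * N * bytes_fp32   # the N x N matrix: ~268 MB
# Naive schedule: write S, read S for softmax, write P, read P for P.V
hbm_traffic = 4 * attn_matrix_bytes
flops = 2 * (2 * N * N * d)              # two N x N x d matmuls (Q.K^T and P.V)

time_mem = hbm_traffic / 2e12            # at ~2 TB/s HBM bandwidth
time_compute = flops / 312e12            # at 312 TFLOPS peak

print(f"attention matrix: {attn_matrix_bytes / 1e6:.0f} MB")
print(f"memory time: {time_mem * 1e6:.0f} us, compute time: {time_compute * 1e6:.0f} us")
```

Under these assumptions the memory time comes out roughly an order of magnitude larger than the compute time, which is the whole point: the GPU waits on HBM, not on math.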

Analogy

A chef needs to prepare 100 dishes. The recipe has 3 steps: marinate, sear, sauce.

Naive approach: marinate all 100 dishes → put them in the walk-in → bring all 100 out → sear all 100 → put them back → bring all 100 out → sauce all 100. Lots of walking to the walk-in.

FlashAttention approach: take 8 dishes at a time. Marinate, sear, and sauce all 8 without putting them down. Then pick up the next 8. You never write to the walk-in for intermediate steps — only for the final output.

The walk-in is HBM. The station is SRAM (fast on-chip memory). The insight is that intermediate computation results (the attention matrix) don’t need to touch the walk-in at all if you’re clever about what you process together.

Mechanism in Plain English

The attention computation has three steps:

S = Q·Kᵀ           (similarity scores, N×N matrix)
P = softmax(S)      (normalized weights, N×N matrix)
O = P·V             (weighted values, N×d matrix)
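A minimal pure-Python sketch of this three-step pipeline (no batching, masking, or 1/√d scaling — an illustration of the dataflow, not the GPU kernel):

```python
import math

def naive_attention(Q, K, V):
    """Materializes the full N x N score matrix S and weight matrix P.
    On a GPU, both would round-trip through HBM between kernels."""
    N, d = len(Q), len(V[0])
    # S = Q . K^T  (N x N)
    S = [[sum(q * k for q, k in zip(Q[i], K[j])) for j in range(N)]
         for i in range(N)]
    # P = softmax(S), row-wise
    P = []
    for row in S:
        m = max(row)                            # subtract max for stability
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    # O = P . V  (N x d)
    return [[sum(P[i][j] * V[j][k] for j in range(N)) for k in range(d)]
            for i in range(N)]
```

Both S and P here are full N×N objects — exactly the intermediates FlashAttention avoids writing out.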

The naive approach writes S and P to HBM between steps. FlashAttention fuses all three into one kernel:

  1. Tile Q, K, V into blocks that fit in SRAM (say, 64 rows at a time).

  2. For each block of Q rows:

    • Load the Q block into SRAM
    • For each block of K, V:
      • Load K, V blocks into SRAM
      • Compute the partial attention scores (small block of S)
      • Update the running softmax using the online softmax trick (see below)
      • Accumulate the weighted V sum
    • Write the output O block to HBM
  3. The N×N attention matrix S is never written to HBM. It lives briefly in SRAM and is immediately consumed.

Online softmax trick: softmax needs to divide by the sum of all exponentials, but you’re processing in blocks. You can maintain running statistics — the current max m and the running sum l of exponentials — and rescale the accumulated output whenever a block contains a higher max. This makes the computation numerically equivalent to full softmax without ever needing the full S matrix.
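The running-statistics update can be sketched on a stream of score blocks (pure Python; the names m and l follow the text):

```python
import math

def online_softmax_stats(score_blocks):
    """Consume scores block by block, maintaining the running max m and
    the running sum l of exp(score - m). When a new block raises the max,
    the old sum is rescaled by exp(m_old - m_new). The final (m, l) match
    what full softmax over the concatenated scores would use."""
    m = float("-inf")   # running max
    l = 0.0             # running sum of exponentials, relative to m
    for block in score_blocks:
        m_new = max(m, max(block))
        # Rescale the old sum to the new reference max, then add this block.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, l
```

In the full algorithm the same rescale factor is also applied to the accumulated output rows, so the weighted V sum stays consistent with the updated statistics.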

ASCII Diagram

  GPU memory hierarchy:
  
  ┌────────────────────────────────┐  ← HBM (off-chip, 40-80GB, ~2 TB/s)
  │  Q, K, V tensors               │    slow, but large
  │  Output O                      │
  └───────────────┬────────────────┘
                  │ load/store (expensive)
  ┌───────────────▼────────────────┐  ← SRAM (on-chip, ~20MB, ~19 TB/s)
  │  Q_block  K_block  V_block     │    fast, but tiny
  │  partial scores (small tile)   │
  │  running softmax stats (m, l)  │
  └────────────────────────────────┘

  Standard attention:
  
  Q,K → [HBM write] S (N×N) → [HBM read] → softmax → [HBM write] P → [HBM read] → P·V → O
                     ↑ 268MB of HBM traffic for N=8192

  FlashAttention:
  
  For each Q_block × K_block tile:
    [SRAM only: compute partial S, update running softmax, accumulate O]
  [HBM write] O (final output only)
  
  HBM writes: O(N) instead of O(N²). N×N matrix never exists in HBM.

Math: IO Complexity

Standard attention HBM accesses:

  • O(Nd) for reading Q, K, V (each N×d)
  • O(N²) for writing and reading the full attention matrix
  • Total: O(Nd + N²)

FlashAttention HBM accesses:

  • O(N²d²/M), where M is the SRAM size (e.g., ~20MB on A100)
  • For typical d (64–128) and M: meaningfully fewer accesses than O(N²) for moderate N
  • The key: you’re re-reading K and V tiles (each sized to fit in M) repeatedly, but never writing the N×N matrix

Memory footprint:

  • Standard: O(N²) — full attention matrix in HBM
  • FlashAttention: O(N) — only tiles in SRAM, final output in HBM

This is what makes long contexts possible. At N = 64K, standard attention needs ~17GB in float32 just for the attention matrix. FlashAttention’s footprint stays bounded by SRAM size.
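Plugging concrete numbers into these complexity classes (a sketch; d = 64 is an assumed head dim, and M is the ~20MB A100 SRAM figure expressed in fp32 elements):

```python
# HBM-access counts (in elements) for the two complexity classes above.
N, d = 8192, 64
M = 20 * 2**20 // 4        # ~20MB of SRAM as fp32 element count

standard = N * d + N * N   # O(Nd + N^2)
flash = N * N * d * d // M # O(N^2 d^2 / M)

print(f"standard attention : {standard / 1e6:.1f}M element accesses")
print(f"flash attention    : {flash / 1e6:.2f}M element accesses")
# Note: real kernels see a smaller gap — each thread block only gets a few
# hundred KB of SRAM, not the full 20MB, and constant factors matter.
```

The asymptotic formula with the full-chip M overstates the practical gap, but the direction is right: the O(N²) term dominating standard attention simply disappears from the FlashAttention access count.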

Concrete Walkthrough

Sequence length N=4, head dim d=2, block size=2.

Q = [[1,0],[0,1],[1,1],[0,0]]
K = [[1,0],[0,1],[1,0],[0,1]]
V = [[10,0],[0,10],[5,5],[3,7]]

Block 1: Q_block = Q[0:2] = [[1,0],[0,1]]

  Inner loop, K_block[0:2] = K[0:2]:
  - Partial scores: S[0:2, 0:2] = Q[0:2]·K[0:2]ᵀ = [[1,0],[0,1]]
  - max so far: m = [1, 1]
  - Running sum: exp([[1,0],[0,1]] - [[1],[1]]) = [[1.0,0.37],[0.37,1.0]], row sums: [1.37, 1.37]
  - O contribution: [[1.0,0.37],[0.37,1.0]] · V[0:2] = [[10,3.7],[3.7,10]] (unnormalized)

  Inner loop, K_block[2:4] = K[2:4]:
  - Partial scores: S[0:2, 2:4] = [[1,0],[0,1]]
  - New max: max over this block is also 1, so m is unchanged and no rescale is needed here (in general, a higher max triggers rescaling the previous accumulation)
  - Update running sum, accumulate V[2:4] contribution

After both inner loops, normalize by l (running sum) → output rows 0 and 1.

At no point did we write a 4×4 matrix to HBM. The partial 2×2 blocks lived and died in SRAM.
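The walkthrough can be run end to end (a pure-Python sketch of the tiled forward pass with the online softmax; no 1/√d scaling, matching the numbers above):

```python
import math

def flash_attention(Q, K, V, block=2):
    """Tiled attention: for each Q block, stream over K/V blocks while
    maintaining running max m, running sum l, and an unnormalized output
    accumulator per row. The full N x N score matrix is never formed."""
    N, d = len(Q), len(V[0])
    O = []
    for i0 in range(0, N, block):
        Qb = Q[i0:i0 + block]
        m = [float("-inf")] * len(Qb)        # running row max
        l = [0.0] * len(Qb)                  # running row sum
        acc = [[0.0] * d for _ in Qb]        # unnormalized output rows
        for j0 in range(0, N, block):        # inner loop over K/V blocks
            Kb, Vb = K[j0:j0 + block], V[j0:j0 + block]
            for r, q in enumerate(Qb):
                scores = [sum(a * b for a, b in zip(q, k)) for k in Kb]
                m_new = max(m[r], max(scores))
                scale = math.exp(m[r] - m_new)     # rescale old statistics
                l[r] = l[r] * scale + sum(math.exp(s - m_new) for s in scores)
                acc[r] = [a * scale for a in acc[r]]
                for s, v in zip(scores, Vb):
                    w = math.exp(s - m_new)
                    acc[r] = [a + w * vk for a, vk in zip(acc[r], v)]
                m[r] = m_new
        # Normalize by l only once, after all K/V blocks are consumed.
        O.extend([a / l[r] for a in acc[r]] for r in range(len(Qb)))
    return O

Q = [[1, 0], [0, 1], [1, 1], [0, 0]]
K = [[1, 0], [0, 1], [1, 0], [0, 1]]
V = [[10, 0], [0, 10], [5, 5], [3, 7]]
O = flash_attention(Q, K, V)
```

Running this on the walkthrough inputs reproduces exact softmax attention: row 0 sees scores [1, 0, 1, 0], and row 3 (the all-zero query) averages all of V uniformly.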

What’s Clever

The insight is recognizing attention is memory bandwidth bound, not compute bound.

Everyone assumed the bottleneck was the O(N²) FLOPs. Dao et al. measured and showed: with standard attention, the GPU is idle 60% of the time waiting for HBM reads. The O(N²) compute was finishing faster than the memory transfers.

The solution isn’t approximating attention (sparse attention, linear attention). It’s reorganizing exact attention to fit memory access patterns of modern GPU hardware. Same math, same result, completely different memory access pattern.

FlashAttention-2 (2023) addresses a second bottleneck: the original implementation only reached 25-40% of theoretical A100 FLOPs. Even without the N×N matrix bottleneck, parallelism across thread blocks was suboptimal. FA2 reaches 50-73% of peak FLOPs by rewriting the work partitioning across GPU warps — a systems insight, not an algorithmic one.

The backward pass deserves mention: instead of saving the N×N softmax output as an activation, FlashAttention recomputes the attention scores on-chip during the backward pass. This trades compute for memory — acceptable because compute is cheap and activation memory is precious.

Open Questions

  • Extending to multi-query and grouped-query attention (largely done in FA2)
  • Adapting for non-standard attention patterns (sliding window, cross-attention)