The Problem
A single transformer layer in a frontier-scale model has weight matrices too large to hold comfortably on one GPU once training state is included. A 175B model (GPT-3) has hidden dim 12288; the MLP up-projection has shape [12288, 49152]. In fp16 that single matrix is 1.2 GB; with gradients and Adam states (mixed precision, ~16 bytes per parameter) it is 9.6 GB. Stack 96 layers and the up-projections alone need ~930 GB, or ~1.9 TB once the equally sized down-projections are counted, before attention or activations. Pipeline parallelism (splitting layers across GPUs) helps but cannot make an individual layer smaller, because its finest unit is a whole layer. You need to split within a layer.
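The arithmetic above can be checked with a few lines of Python. This is a sketch under one stated assumption: mixed-precision Adam costs 16 bytes per parameter (fp16 weight and gradient, plus fp32 master weight and two fp32 Adam moments). That figure is what turns 1.2 GB of weights into roughly 9.6 GB of training state.

```python
# Back-of-the-envelope memory for GPT-3-scale MLP weights.
# Assumption: mixed-precision Adam at 16 bytes/param
# (fp16 weight 2 + fp16 grad 2 + fp32 master 4 + fp32 Adam m 4 + fp32 Adam v 4).
HIDDEN = 12288
FFN = 4 * HIDDEN              # 49152
LAYERS = 96
BYTES_FP16 = 2
BYTES_PER_PARAM_ADAM = 16

params_up = HIDDEN * FFN                                  # one up-projection matrix
weights_gb = params_up * BYTES_FP16 / 1e9                 # fp16 weights only
train_state_gb = params_up * BYTES_PER_PARAM_ADAM / 1e9   # weights + grads + Adam state
up_proj_all_layers_gb = train_state_gb * LAYERS           # up-projections only
both_mlp_all_layers_gb = 2 * up_proj_all_layers_gb        # up- plus down-projections

print(f"one up-projection, fp16 weights: {weights_gb:.2f} GB")
print(f"... with grads + Adam state:     {train_state_gb:.2f} GB")
print(f"96 up-projections:               {up_proj_all_layers_gb:.0f} GB")
print(f"96 x (up + down):                {both_mlp_all_layers_gb:.0f} GB")
```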
The Key Insight
The matrix multiplications inside a transformer layer can be split across GPUs by partitioning the weight matrices. Different GPUs compute different slices of the output. With careful placement of all-reduces, the result is mathematically identical to single-GPU computation.
The Megatron-LM design exploits the structure of the transformer: in the MLP, the first weight matrix is split column-wise, the second row-wise, and the GELU between them is element-wise (so it commutes with the column split). Only one all-reduce is needed in the forward pass, at the very end of the MLP, plus one in the backward pass for the gradient with respect to the MLP input.
Mechanism in Plain English
MLP block (Y = GELU(X @ W1) @ W2):
- Split W1 by columns across N GPUs. GPU k holds W1[:, k*m:(k+1)*m], where m = 4*hidden/N.
- Each GPU computes its slice: intermediate_k = GELU(X @ W1_slice).
- Split W2 by rows across the same N GPUs. Each GPU holds W2[k*m:(k+1)*m, :].
- Each GPU computes a partial sum: out_k = intermediate_k @ W2_slice.
- All-reduce across the N GPUs to sum the partials into the final output Y.
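The MLP steps above can be simulated on one machine to confirm the split is exact. In this minimal NumPy sketch the N GPUs are just loop iterations, and the all-reduce is a plain sum over the partial outputs. Biases are omitted, and the function names are illustrative rather than Megatron-LM's API.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU; purely element-wise
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp_reference(X, W1, W2):
    # single-GPU computation: Y = GELU(X @ W1) @ W2 (biases omitted)
    return gelu(X @ W1) @ W2

def mlp_tensor_parallel(X, W1, W2, n):
    # simulated TP: loop index k plays the role of GPU k; X is replicated
    ffn = W1.shape[1]
    assert ffn % n == 0, "intermediate size must divide evenly across ranks"
    m = ffn // n
    partials = []
    for k in range(n):
        W1_k = W1[:, k * m:(k + 1) * m]          # column slice of W1
        W2_k = W2[k * m:(k + 1) * m, :]          # matching row slice of W2
        intermediate_k = gelu(X @ W1_k)          # local, no communication
        partials.append(intermediate_k @ W2_k)   # partial sum of the full output
    # stand-in for the all-reduce: every rank would receive this same sum
    return np.sum(partials, axis=0)

rng = np.random.default_rng(0)
hidden, tokens, n = 16, 5, 4
X = rng.standard_normal((tokens, hidden))
W1 = rng.standard_normal((hidden, 4 * hidden))
W2 = rng.standard_normal((4 * hidden, hidden))

Y_ref = mlp_reference(X, W1, W2)
Y_tp = mlp_tensor_parallel(X, W1, W2, n)
print(np.allclose(Y_ref, Y_tp))  # True: equal up to floating-point summation order
```

The key is keeping W1's column slice and W2's matching row slice on the same rank. That way, rank k's intermediate never has to leave rank k.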
Attention block:
- Heads are split across GPUs. Each GPU owns a subset of heads.
- The Q, K, V projection matrices are split column-wise by head, so each GPU computes its assigned heads’ attention independently (Q, K, V projections, softmax, weighted sum) with no communication.
- Output projection: split row-wise like W2 above. All-reduce sums the partials.
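The same check works for attention. It assumes heads are laid out contiguously in the columns of the Q, K, V projections, so head h occupies columns h*head_dim to (h+1)*head_dim. The sketch is non-causal and has no biases. Ranks are simulated as loop iterations, and the helper names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along `axis`
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_heads(X, Wq, Wk, Wv, head_dim):
    # non-causal multi-head attention over the heads contained in Wq/Wk/Wv's columns
    seq = X.shape[0]
    n_heads = Wq.shape[1] // head_dim
    def split_heads(t):  # [seq, n_heads*head_dim] -> [n_heads, seq, head_dim]
        return t.reshape(seq, n_heads, head_dim).transpose(1, 0, 2)
    q, k, v = split_heads(X @ Wq), split_heads(X @ Wk), split_heads(X @ Wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)   # [n_heads, seq, seq]
    out = softmax(scores) @ v                                 # [n_heads, seq, head_dim]
    return out.transpose(1, 0, 2).reshape(seq, n_heads * head_dim)

def attention_tensor_parallel(X, Wq, Wk, Wv, Wo, head_dim, n):
    hidden = Wq.shape[1]
    assert (hidden // head_dim) % n == 0, "heads must divide evenly across ranks"
    cols = hidden // n  # features owned by one rank = heads_per_rank * head_dim
    partials = []
    for r in range(n):
        sl = slice(r * cols, (r + 1) * cols)
        heads_r = attention_heads(X, Wq[:, sl], Wk[:, sl], Wv[:, sl], head_dim)  # local heads
        partials.append(heads_r @ Wo[sl, :])   # row slice of the output projection
    return np.sum(partials, axis=0)            # stand-in for the all-reduce

rng = np.random.default_rng(1)
seq, hidden, head_dim, n = 6, 32, 4, 4       # 8 heads, 2 per rank
X = rng.standard_normal((seq, hidden))
Wq, Wk, Wv, Wo = (rng.standard_normal((hidden, hidden)) / np.sqrt(hidden) for _ in range(4))

Y_ref = attention_heads(X, Wq, Wk, Wv, head_dim) @ Wo
Y_tp = attention_tensor_parallel(X, Wq, Wk, Wv, Wo, head_dim, n)
print(np.allclose(Y_ref, Y_tp))  # True
```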
Embedding / Output projection:
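- Vocabulary split across GPUs. Each GPU holds the embedding rows for its slice of the vocabulary and computes logits only for that slice. For the input embedding, each GPU looks up only the token ids in its slice (zeros elsewhere), and an all-reduce assembles the full embeddings.
- Cross-entropy needs a log-sum-exp over the whole vocabulary. Each GPU computes its local max and its local sum of exponentials. All-reducing those (a max, then a sum), together with the target-token logit that lives on exactly one GPU, gives the exact loss. Only [batch, seq] scalars are communicated instead of the full [batch, seq, vocab] logits.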
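A vocabulary-parallel cross-entropy can be sketched the same way. This NumPy version assumes the vocabulary divides evenly across ranks. It simulates each all-reduce as a max or sum over a Python list, and the helper names are hypothetical. Every reduced quantity has one value per token, so the traffic scales with [batch, seq] rather than [batch, seq, vocab].

```python
import numpy as np

def cross_entropy_reference(logits, targets):
    # logits: [tokens, vocab]; targets: [tokens] integer ids
    m = logits.max(axis=1, keepdims=True)
    lse = np.log(np.exp(logits - m).sum(axis=1)) + m[:, 0]
    return lse - logits[np.arange(len(targets)), targets]

def cross_entropy_vocab_parallel(logit_shards, targets):
    # logit_shards[r]: logits for vocab ids [r*V_r, (r+1)*V_r), held only by rank r
    tokens = len(targets)
    shard_size = logit_shards[0].shape[1]
    # 1) all-reduce(MAX) of per-rank maxima, for numerical stability: [tokens]
    global_max = np.max([s.max(axis=1) for s in logit_shards], axis=0)
    # 2) all-reduce(SUM) of per-rank sums of exp(logit - global_max): [tokens]
    sum_exp = np.sum([np.exp(s - global_max[:, None]).sum(axis=1) for s in logit_shards], axis=0)
    # 3) all-reduce(SUM) of the target logit; exactly one rank owns each target id
    target_logit = np.zeros(tokens)
    for r, s in enumerate(logit_shards):
        lo, hi = r * shard_size, (r + 1) * shard_size
        owned = (targets >= lo) & (targets < hi)
        target_logit[owned] += s[np.arange(tokens)[owned], targets[owned] - lo]
    return np.log(sum_exp) + global_max - target_logit

rng = np.random.default_rng(2)
tokens, vocab, n = 7, 40, 4
logits = rng.standard_normal((tokens, vocab)) * 3
targets = rng.integers(0, vocab, size=tokens)
shards = np.split(logits, n, axis=1)   # each simulated rank sees only its vocab slice
print(np.allclose(cross_entropy_reference(logits, targets),
                  cross_entropy_vocab_parallel(shards, targets)))  # True
```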
ASCII Diagram
MLP block, TP=4:
```text
                           X  (full input, replicated on all 4 GPUs)
                           |
      +-------------+------+------+-------------+
      |             |             |             |
    GPU 0         GPU 1         GPU 2         GPU 3
  W1[:,0:m]    W1[:,m:2m]    W1[:,2m:3m]   W1[:,3m:4m]   <- column split
      |             |             |             |
   X@W1_0        X@W1_1        X@W1_2        X@W1_3      <- intermediate, no comm needed
      |             |             |             |
    GELU          GELU          GELU          GELU       <- element-wise, no comm
      |             |             |             |
  W2[0:m,:]    W2[m:2m,:]    W2[2m:3m,:]   W2[3m:4m,:]   <- row split
      |             |             |             |
   partial       partial       partial       partial     <- partial sum of full output
      |             |             |             |
      +-------------+------+------+-------------+
                           |
                      ALL-REDUCE (sum across GPUs)
                           |
                           Y  (full output, replicated)
```
Communication: 1 all-reduce per MLP block in the forward pass (and 1 in the backward pass), each of size [batch, seq, hidden].
What’s Clever
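The choice of column-split for W1 then row-split for W2 is what eliminates intermediate communication. If you flipped the order (row-split W1, column-split W2), each GPU would hold only a partial sum of the full intermediate. Because GELU is nonlinear (GELU(a + b) ≠ GELU(a) + GELU(b)), those partials would have to be all-reduced before the GELU, adding a synchronization point between W1 and W2. The Megatron pattern avoids that.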
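A small NumPy experiment makes the ordering argument concrete. It assumes the flipped layout (X split by columns, W1 split by rows), in which each rank produces a partial sum of the full intermediate. Applying GELU before summing then gives the wrong answer. The sizes and seed are arbitrary.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU; element-wise but nonlinear
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(3)
hidden, tokens, n = 16, 5, 4
X = rng.standard_normal((tokens, hidden))
W1 = rng.standard_normal((hidden, 4 * hidden))
h = hidden // n

# Flipped layout: rank r holds X's column block r and W1's row block r,
# so X_r @ W1_r is a partial sum of the FULL [tokens, 4*hidden] intermediate.
partials = [X[:, r * h:(r + 1) * h] @ W1[r * h:(r + 1) * h, :] for r in range(n)]

wrong = np.sum([gelu(p) for p in partials], axis=0)  # GELU applied before reducing
right = gelu(np.sum(partials, axis=0))               # reduce (all-reduce) first, then GELU
print(np.allclose(right, gelu(X @ W1)))  # True: correct, but needs a sync before GELU
print(np.allclose(wrong, gelu(X @ W1)))  # False: GELU does not distribute over the sum
```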
The element-wise GELU is the crucial enabler: it commutes with the column split because it operates per element. A softmax along the hidden dimension between W1 and W2 would break this, since each GPU would need the normalizer over the full intermediate, which requires communication.
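The same kind of check contrasts an element-wise op with a softmax over the split axis. Here Z is a random stand-in for X @ W1, split column-wise across two simulated GPUs. Both helpers are local definitions, not library calls.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU; element-wise
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x, axis=-1):
    # normalizes over the whole of `axis`, so it needs every column
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(4)
Z = rng.standard_normal((3, 8))     # stand-in for X @ W1
halves = np.split(Z, 2, axis=1)     # column split across 2 simulated GPUs

# Element-wise op: per-shard results concatenate to the full result.
print(np.allclose(np.concatenate([gelu(s) for s in halves], axis=1), gelu(Z)))        # True
# Softmax over the split axis: each shard normalizes over only its own columns.
print(np.allclose(np.concatenate([softmax(s) for s in halves], axis=1), softmax(Z)))  # False
```

Each shard's softmax rows sum to 1 on their own, so the concatenated rows sum to 2. The shared normalizer is exactly the piece that would need communication.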
For attention: heads are independent computations, so splitting heads across GPUs is naturally parallel. The KV-cache for each head lives entirely on one GPU.
Concrete Walkthrough
GPT-3 175B with TP=8:
Per-layer compute per GPU:
Attention: 96 heads / 8 = 12 heads per GPU; head dim = 12288 / 96 = 128, so each GPU owns 12 × 128 = 1536 of the 12288 features.
Q, K, V each: [seq, 1536] per GPU, i.e. 12 heads of [seq, 128].
Softmax + weighted sum: per-head, no cross-GPU.
Output projection: row-split + all-reduce.
MLP: hidden 12288, expansion 4x = 49152, split 8 ways.
Per-GPU intermediate: [seq, 6144].
All-reduce at end: [batch, seq, 12288].
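The per-GPU shapes in the walkthrough follow from a handful of divisions. This helper assumes the standard GPT-3 175B configuration (hidden 12288, 96 heads, 4x MLP expansion) and batch=1. The function name and dictionary keys are made up for illustration.

```python
# Per-rank shapes for one GPT-3 175B layer under tensor parallelism.
# Assumption: standard GPT-3 config (hidden 12288, 96 heads, FFN = 4 * hidden), batch=1.
def per_rank_shapes(hidden=12288, n_heads=96, tp=8, seq=2048):
    assert n_heads % tp == 0 and (4 * hidden) % tp == 0
    head_dim = hidden // n_heads
    heads_per_rank = n_heads // tp
    ffn_per_rank = 4 * hidden // tp
    return {
        "head_dim": head_dim,
        "heads_per_rank": heads_per_rank,
        "qkv_each": (seq, heads_per_rank * head_dim),  # Q, K or V on one rank
        "W1_slice": (hidden, ffn_per_rank),            # column slice
        "W2_slice": (ffn_per_rank, hidden),            # row slice
        "mlp_intermediate": (seq, ffn_per_rank),
        "allreduce_payload": (seq, hidden),            # one sequence
    }

for name, value in per_rank_shapes().items():
    print(f"{name:18s} {value}")
```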
Communication per layer:
- 1 all-reduce after attention output projection
- 1 all-reduce after MLP
For batch=1, seq=2048, hidden=12288, fp16:
All-reduce size = 2048 * 12288 * 2 bytes = 50 MB per all-reduce.
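Per layer: 100 MB of all-reduce payload (forward) + 100 MB (backward) = 200 MB.
96 layers: 19.2 GB per step.
At 300 GB/s NVLink: ~64 ms communication overhead per step. This treats each all-reduce as moving its payload once; a ring all-reduce moves about 2(N-1)/N times the payload per GPU (1.75x at N=8), so the true cost is somewhat higher.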
If forward+backward compute is ~500 ms, communication overhead is 13%.
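Going to TP=16 with cross-node communication (25 GB/s InfiniBand) would mean:
Per all-reduce: 50 MB / 25 GB/s = 2 ms (vs 0.17 ms over NVLink).
Per step: 19.2 GB / 25 GB/s ≈ 770 ms of communication, more than the ~500 ms of compute (over 150% overhead). Per-GPU compute also shrinks as the TP degree grows, so the ratio only gets worse.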
TP across nodes is impractical at this scale.
This is why TP is usually confined to a single node (typically degree 4 or 8 over NVLink).
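The communication estimates can be reproduced with a small cost model. Like the walkthrough, it assumes each all-reduce moves its payload once over the link (ring_factor=1.0) and that there are four all-reduces per layer per step. It also uses decimal GB and assumes ~500 ms of forward+backward compute. Setting ring_factor to 2*(N-1)/N gives a more realistic ring all-reduce estimate. The function name is hypothetical.

```python
def tp_comm_seconds(layers=96, seq=2048, hidden=12288, batch=1, bytes_per_elem=2,
                    allreduces_per_layer=4, link_gb_per_s=300.0, ring_factor=1.0):
    """Per-step tensor-parallel communication time under a simple bandwidth model.

    allreduces_per_layer=4: attention + MLP, forward and backward.
    ring_factor=1.0 reproduces 'payload / bandwidth'; a ring all-reduce over
    N ranks moves about 2*(N-1)/N times the payload per rank.
    """
    payload_bytes = batch * seq * hidden * bytes_per_elem          # ~50 MB
    total_bytes = payload_bytes * allreduces_per_layer * layers * ring_factor
    return total_bytes / (link_gb_per_s * 1e9)

compute_s = 0.5  # assumed forward+backward compute per step, taken from the text
for label, bw in [("NVLink 300 GB/s", 300.0), ("IB 25 GB/s", 25.0)]:
    t = tp_comm_seconds(link_gb_per_s=bw)
    print(f"{label}: {t * 1e3:.0f} ms comm, {t / compute_s:.0%} of compute")
```

With these defaults it gives about 64 ms (roughly 13% of compute) over NVLink. Over 25 GB/s InfiniBand it gives about 770 ms, so communication alone exceeds compute.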
Key Sources
- megatron-lm-training-multi-billion-parameter-language-models — the canonical formulation
Related Concepts
- model-parallel — parent category
- distributed-training — broader umbrella
- transformer — the architecture being parallelized
Open Questions
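- Sequence parallelism. Activations are still replicated across TP ranks at all-reduce points. Sequence parallelism splits them along the sequence axis to save memory. In Megatron's formulation, each all-reduce is replaced by a reduce-scatter plus an all-gather, so the communication volume stays roughly the same.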
- TP for non-MLP shapes. Models with different topology (MoE, mixture-of-depths) need different TP patterns.
- Auto-TP. Choosing the right TP degree per layer is mostly manual. Automated tools (Alpa, PyTorch DTensor) help but are not yet universal.
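The sequence-parallelism item rests on a simple identity: an all-reduce equals a reduce-scatter followed by an all-gather. Between the two, each rank holds only a 1/N slice of the result. This NumPy sketch simulates the three collectives over a list of per-rank arrays, sharding along the sequence axis. The helper names are hypothetical; real collectives would come from a communication library such as NCCL.

```python
import numpy as np

def all_reduce(xs):
    # every rank receives the elementwise sum of all ranks' arrays
    total = np.sum(xs, axis=0)
    return [total.copy() for _ in xs]

def reduce_scatter(xs, axis=0):
    # rank r receives the sum of every rank's r-th chunk along `axis`
    n = len(xs)
    chunks = [np.split(x, n, axis=axis) for x in xs]
    return [np.sum([chunks[src][r] for src in range(n)], axis=0) for r in range(n)]

def all_gather(xs, axis=0):
    # every rank receives the concatenation of all ranks' chunks
    full = np.concatenate(xs, axis=axis)
    return [full.copy() for _ in xs]

rng = np.random.default_rng(5)
n, seq, hidden = 4, 8, 6
partials = [rng.standard_normal((seq, hidden)) for _ in range(n)]  # one per TP rank

ar = all_reduce(partials)
rs = reduce_scatter(partials, axis=0)   # sequence-sharded: seq/n rows per rank
ag = all_gather(rs, axis=0)
print(all(np.allclose(a, b) for a, b in zip(ar, ag)))  # True
print(rs[0].shape)                                       # (2, 6)
```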