The Problem

A single transformer layer in a frontier-scale model has matrices whose weights, gradients, and optimizer state quickly outgrow a single GPU. A 175B model has hidden dim 12288; the MLP up-projection has shape [12288, 49152]. In fp16 that single matrix is 1.2 GB; with gradients and Adam states (about 16 bytes per parameter in mixed precision) it is 9.6 GB. Across 96 layers that is ~900 GB for the up-projections alone, and the down-projections double it, before attention or activations. Pipeline parallelism (splitting layers across GPUs) helps but cannot fix an individual layer being too large. You need to split within a layer.
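
The figures above can be rechecked with a few lines of arithmetic. This is only a sketch: the 16-bytes-per-parameter breakdown is an assumption (it is not spelled out in the text) chosen because it reproduces the 9.6 GB figure.

```python
# Back-of-the-envelope memory for one MLP up-projection of a 175B-class model.
# Assumption (not stated in the text): mixed-precision Adam keeps
# fp16 weight (2 B) + fp16 gradient (2 B) + fp32 master weight (4 B)
# + fp32 Adam first and second moments (4 B + 4 B) = 16 bytes per parameter.
HIDDEN = 12288
FFN = 4 * HIDDEN           # 49152
NUM_LAYERS = 96
BYTES_FP16 = 2
BYTES_TRAIN_STATE = 16     # assumed breakdown, see comment above

params = HIDDEN * FFN
weights_gb = params * BYTES_FP16 / 1e9
train_state_gb = params * BYTES_TRAIN_STATE / 1e9
all_layers_gb = train_state_gb * NUM_LAYERS

print(f"parameters per up-projection: {params:,}")
print(f"fp16 weights:                 {weights_gb:.2f} GB")
print(f"with grads + Adam state:      {train_state_gb:.2f} GB")
print(f"96 up-projections:            {all_layers_gb:.0f} GB")
```

The totals count only the up-projection of each MLP; every layer also has a down-projection of the same size.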

The Key Insight

The matrix multiplications inside a transformer layer can be split across GPUs by partitioning the weight matrices. Different GPUs compute different slices of the output. With careful placement of all-reduces, the result is mathematically identical to single-GPU computation.

The Megatron-LM design exploits the structure of the transformer: in the MLP, the first weight matrix is split column-wise, the second row-wise, and the GELU between them is element-wise (so it commutes with the column split). Only one all-reduce is needed at the very end of the MLP.

Mechanism in Plain English

MLP block (Y = GELU(X @ W1) @ W2):

  1. Split W1 by columns across N GPUs. GPU k (k = 0..N-1) holds W1[:, k*m:(k+1)*m] where m = 4*hidden/N.
  2. Each GPU computes its slice: intermediate_k = GELU(X @ W1_slice).
  3. Split W2 by rows across the same N GPUs. Each GPU holds W2[k*m:(k+1)*m, :].
  4. Each GPU computes a partial sum: out_k = intermediate_k @ W2_slice.
  5. All-reduce across the N GPUs to sum the partials into the final output Y.
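
The five steps can be simulated on one machine with NumPy, which makes it easy to confirm that the split result matches the unsplit one. This is a minimal sketch, not Megatron-LM code: the function names are made up for illustration, the loop iterations stand in for GPUs, a plain sum stands in for the all-reduce, and the tanh approximation of GELU is an arbitrary choice.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU; applied element-wise
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp_reference(x, w1, w2):
    return gelu(x @ w1) @ w2

def mlp_tensor_parallel(x, w1, w2, n_gpus):
    """Simulate the column/row split; each loop iteration plays one GPU."""
    m = w1.shape[1] // n_gpus                        # m = 4*hidden / N
    partials = []
    for k in range(n_gpus):
        w1_slice = w1[:, k * m:(k + 1) * m]          # step 1: column split of W1
        intermediate_k = gelu(x @ w1_slice)          # step 2: local, no communication
        w2_slice = w2[k * m:(k + 1) * m, :]          # step 3: row split of W2
        partials.append(intermediate_k @ w2_slice)   # step 4: partial sum of the output
    return np.sum(partials, axis=0)                  # step 5: stands in for the all-reduce

rng = np.random.default_rng(0)
hidden, n_gpus, tokens = 16, 4, 5                    # toy sizes, for illustration only
x = rng.standard_normal((tokens, hidden))
w1 = rng.standard_normal((hidden, 4 * hidden))
w2 = rng.standard_normal((4 * hidden, hidden))

y_ref = mlp_reference(x, w1, w2)
y_tp = mlp_tensor_parallel(x, w1, w2, n_gpus)
print(np.allclose(y_ref, y_tp))
```

In a real setup each rank would hold only its own slices of W1 and W2 and the final sum would be a collective operation across devices.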

Attention block:

  1. Heads are split across GPUs. Each GPU owns a subset of heads.
  2. Each GPU computes its assigned heads’ attention independently (Q, K, V projections, softmax, weighted sum).
  3. Output projection: split row-wise like W2 above. All-reduce sums the partials.
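
The same kind of single-machine simulation works for attention. The sketch below assumes unmasked self-attention over one sequence and that head h occupies columns h*head_dim:(h+1)*head_dim of the Q, K, V projections; the helper names are illustrative, and the sum again stands in for the all-reduce.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def heads_attention(x, wq, wk, wv, head_dim):
    """Attention for whichever heads the given projection columns contain."""
    seq = x.shape[0]
    n_heads = wq.shape[1] // head_dim
    # [seq, n_heads*head_dim] -> [n_heads, seq, head_dim]
    q = (x @ wq).reshape(seq, n_heads, head_dim).transpose(1, 0, 2)
    k = (x @ wk).reshape(seq, n_heads, head_dim).transpose(1, 0, 2)
    v = (x @ wv).reshape(seq, n_heads, head_dim).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    ctx = softmax(scores) @ v                        # [n_heads, seq, head_dim]
    return ctx.transpose(1, 0, 2).reshape(seq, n_heads * head_dim)

def attention_tensor_parallel(x, wq, wk, wv, wo, head_dim, n_gpus):
    cols = wq.shape[1] // n_gpus                     # whole heads' worth of columns per GPU
    partials = []
    for g in range(n_gpus):
        sl = slice(g * cols, (g + 1) * cols)
        ctx_g = heads_attention(x, wq[:, sl], wk[:, sl], wv[:, sl], head_dim)
        partials.append(ctx_g @ wo[sl, :])           # row-split output projection
    return np.sum(partials, axis=0)                  # stands in for the all-reduce

rng = np.random.default_rng(0)
seq, hidden, n_heads, n_gpus = 6, 32, 8, 4           # toy sizes, for illustration only
head_dim = hidden // n_heads
x = rng.standard_normal((seq, hidden))
wq, wk, wv, wo = (rng.standard_normal((hidden, hidden)) for _ in range(4))

y_ref = heads_attention(x, wq, wk, wv, head_dim) @ wo
y_tp = attention_tensor_parallel(x, wq, wk, wv, wo, head_dim, n_gpus)
print(np.allclose(y_ref, y_tp))
```

The split only works when each GPU receives whole heads, which is why the number of heads has to be divisible by the TP degree.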

Embedding / Output projection:

  1. Vocabulary split across GPUs. Each GPU computes logits for its slice.
  2. Cross-entropy loss requires a sum over the whole vocabulary; each GPU computes a partial log-sum-exp over its slice, and all-reduces (a max for numerical stability, then a sum) combine them.
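
One way to organize the vocabulary-split loss is sketched below. It is an assumption about a reasonable implementation rather than a description of any particular library: it uses three reductions (a max, a sum of exponentials, and the target logit), all simulated with NumPy on one machine, and the function names are illustrative.

```python
import numpy as np

def cross_entropy_reference(h, w_out, targets):
    logits = h @ w_out
    mx = logits.max(axis=-1, keepdims=True)
    lse = mx[:, 0] + np.log(np.exp(logits - mx).sum(axis=-1))
    return lse - logits[np.arange(len(targets)), targets]

def cross_entropy_vocab_parallel(h, w_out, targets, n_gpus):
    shard = w_out.shape[1] // n_gpus                 # vocabulary entries per GPU
    tokens = np.arange(len(targets))
    local_logits = [h @ w_out[:, g * shard:(g + 1) * shard] for g in range(n_gpus)]

    # reduction 1 (max): global maximum logit per token, for numerical stability
    global_max = np.max([lg.max(axis=-1) for lg in local_logits], axis=0)

    # reduction 2 (sum): sum of exp(logit - max) over each vocabulary shard
    sum_exp = np.sum(
        [np.exp(lg - global_max[:, None]).sum(axis=-1) for lg in local_logits], axis=0
    )

    # reduction 3 (sum): the target logit, contributed only by the shard that owns it
    target_logit = np.zeros(len(targets))
    for g, lg in enumerate(local_logits):
        owned = (targets >= g * shard) & (targets < (g + 1) * shard)
        target_logit[owned] += lg[tokens[owned], targets[owned] - g * shard]

    return global_max + np.log(sum_exp) - target_logit

rng = np.random.default_rng(0)
n_tokens, hidden, vocab, n_gpus = 7, 8, 20, 4        # toy sizes, for illustration only
h = rng.standard_normal((n_tokens, hidden))
w_out = rng.standard_normal((hidden, vocab))
targets = rng.integers(0, vocab, size=n_tokens)

loss_ref = cross_entropy_reference(h, w_out, targets)
loss_tp = cross_entropy_vocab_parallel(h, w_out, targets, n_gpus)
print(np.allclose(loss_ref, loss_tp))
```

Each reduction moves one number per token instead of one per vocabulary entry, so the full logits never have to be gathered on a single device.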

ASCII Diagram

MLP block, TP=4:

                         X (full input, replicated on all 4 GPUs)
                         |
    +-------------+------+------+-------------+
    |             |             |             |
  GPU 0         GPU 1         GPU 2         GPU 3
W1[:,0:m]     W1[:,m:2m]    W1[:,2m:3m]   W1[:,3m:4m]   <- column split
    |             |             |             |
 X@W1_0        X@W1_1        X@W1_2        X@W1_3       <- intermediate, no comm needed
    |             |             |             |
  GELU          GELU          GELU          GELU        <- element-wise, no comm
    |             |             |             |
W2[0:m,:]     W2[m:2m,:]    W2[2m:3m,:]   W2[3m:4m,:]   <- row split
    |             |             |             |
 partial       partial       partial       partial      <- partial sum of full output
    |             |             |             |
    +-------------+------+------+-------------+
                         |
           ALL-REDUCE (sum across GPUs)
                         |
                         Y (full output, replicated)

Communication: 1 all-reduce per MLP block in the forward pass (and 1 more in the backward pass), each of size [batch, seq, hidden].

What’s Clever

The choice of column-split for W1 then row-split for W2 is what eliminates intermediate communication. If you flipped the order (row-split W1, column-split W2), each GPU would hold only a partial sum of the intermediate, and because GELU is nonlinear those partials would have to be summed with an all-reduce before the GELU could be applied. The Megatron pattern avoids that.

The element-wise GELU is the crucial enabler. It commutes with the column split because it operates per-element. If you had a softmax along the hidden dimension between W1 and W2, this would not work.
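
A small numerical check makes the role of GELU concrete. In the sketch below (toy sizes, tanh-approximated GELU, both chosen only for illustration), the column split lets each GPU apply GELU locally, while a row split of W1 yields partial sums that must be combined across GPUs before GELU is applied.

```python
import numpy as np

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
hidden, n_gpus = 8, 2
x = rng.standard_normal((3, hidden))
w1 = rng.standard_normal((hidden, 4 * hidden))
full = gelu(x @ w1)

# Column split of W1: each GPU owns whole output columns, so GELU can run locally.
m = w1.shape[1] // n_gpus
col_slices = [gelu(x @ w1[:, k * m:(k + 1) * m]) for k in range(n_gpus)]
column_split_ok = np.allclose(np.concatenate(col_slices, axis=1), full)

# Row split of W1: each GPU sees a slice of X's features and produces a partial sum
# of every output element. GELU is nonlinear, so applying it before summing is wrong.
r = hidden // n_gpus
partials = [x[:, k * r:(k + 1) * r] @ w1[k * r:(k + 1) * r, :] for k in range(n_gpus)]
gelu_then_sum = np.sum([gelu(p) for p in partials], axis=0)
sum_then_gelu = gelu(np.sum(partials, axis=0))       # partials combined before GELU
row_split_needs_comm = not np.allclose(gelu_then_sum, full)

print(column_split_ok, row_split_needs_comm, np.allclose(sum_then_gelu, full))
```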

For attention: heads are independent computations, so splitting heads across GPUs is naturally parallel. The KV-cache for each head lives entirely on one GPU.

Concrete Walkthrough

GPT-3 175B with TP=8:

Per-layer compute per GPU:
  Attention: 96 heads / 8 = 12 heads per GPU; head dim 12288 / 96 = 128, so 12 * 128 = 1536 features per GPU.
            Q, K, V each: [seq, 128] per head, 12 heads ([seq, 1536] per GPU).
            Softmax + weighted sum: per-head, no cross-GPU.
            Output projection: row-split + all-reduce.
  MLP: hidden 12288, expansion 4x = 49152, split 8-ways.
       Per-GPU intermediate: [seq, 6144].
       All-reduce at end: [batch, seq, 12288].
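
The per-GPU sizes follow from integer division of the model dimensions by the TP degree. A small helper, hypothetical and written only for this walkthrough, makes the bookkeeping explicit:

```python
def tp_shard_sizes(hidden, n_heads, tp, expansion=4):
    """Per-GPU sizes under tensor parallelism; names here are illustrative."""
    if n_heads % tp or (expansion * hidden) % tp:
        raise ValueError("heads and MLP width must divide evenly by the TP degree")
    head_dim = hidden // n_heads
    heads_per_gpu = n_heads // tp
    return {
        "head_dim": head_dim,
        "heads_per_gpu": heads_per_gpu,
        "attn_features_per_gpu": heads_per_gpu * head_dim,
        "mlp_intermediate_per_gpu": expansion * hidden // tp,
    }

sizes = tp_shard_sizes(hidden=12288, n_heads=96, tp=8)
print(sizes)
```

The divisibility check reflects the constraint that every GPU must own whole heads and an equal share of the MLP columns.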

Communication per layer:
  - 1 all-reduce after attention output projection
  - 1 all-reduce after MLP

For batch=1, seq=2048, hidden=12288, fp16:
  All-reduce size = 2048 * 12288 * 2 bytes = 50 MB per all-reduce.
  Per layer: 100 MB sent (forward) + 100 MB (backward) = 200 MB.
  96 layers: 19.2 GB per step.
  At 300 GB/s NVLink: ~64 ms communication overhead per step.

If forward+backward compute is ~500 ms, communication overhead is 13%.
Going to TP=16 with cross-node communication (25 GB/s IB) would mean:
  Per all-reduce: 50 MB / 25 GB/s = 2 ms (vs 0.17 ms over NVLink).
  Per step: 384 all-reduces * 2 ms = ~770 ms communication. For the same ~500 ms of compute, communication now exceeds compute (roughly 150% overhead).
  TP across nodes is impractical.

This is why TP is usually kept within a single node (typically degree 4 or 8 with NVLink).
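
The estimates in this walkthrough can be recomputed from the stated tensor size and bandwidths. The sketch below uses the simplest possible model, time = bytes / bandwidth, which is an assumption: it ignores latency, the extra traffic of the all-reduce algorithm itself, and any overlap of communication with compute.

```python
# Simple bandwidth model for tensor-parallel all-reduce traffic per training step.
BATCH, SEQ, HIDDEN = 1, 2048, 12288
BYTES_FP16 = 2
NUM_LAYERS = 96
ALL_REDUCES_PER_LAYER = 2 + 2    # attention + MLP, forward + backward
COMPUTE_MS = 500.0               # forward+backward compute assumed in the text

def comm_ms_per_step(bandwidth_gb_per_s):
    bytes_per_all_reduce = BATCH * SEQ * HIDDEN * BYTES_FP16
    total_bytes = bytes_per_all_reduce * ALL_REDUCES_PER_LAYER * NUM_LAYERS
    return total_bytes / (bandwidth_gb_per_s * 1e9) * 1e3

for name, bw in [("NVLink, 300 GB/s", 300.0), ("InfiniBand, 25 GB/s", 25.0)]:
    ms = comm_ms_per_step(bw)
    print(f"{name}: {ms:.0f} ms communication, {ms / COMPUTE_MS:.0%} of compute time")
```

Because the model is linear in bandwidth, the 12x gap between NVLink and InfiniBand carries straight through to the per-step communication time.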

Open Questions

  • Sequence parallelism. Activations outside the TP-split matmuls (layer norm, dropout) are still replicated across TP ranks. Sequence parallelism splits them along the sequence axis to save memory, replacing each all-reduce with a reduce-scatter plus an all-gather.
  • TP for non-MLP shapes. Models with different topology (MoE, mixture-of-depths) need different TP patterns.
  • Auto-TP. Choosing the right TP degree per layer is mostly manual. Automated tools (Alpa, PyTorch DTensor) help but are not yet universal.