Concepts: model-parallel | tensor-parallel | distributed-training | memory-efficiency Builds on: attention-is-all-you-need Leads to: zero-memory-optimizations-trillion-parameter-models
In 2019 the consensus way to scale a neural network past one GPU was data parallelism: replicate the whole model on every GPU, feed each replica a different slice of the batch, and all-reduce the gradients. That approach hits a wall once the model itself stops fitting on one GPU. For a transformer with 8B parameters in fp16, the parameters alone are 16GB. Add gradients, optimizer states, and activations and you are well past 100GB, beyond what a single V100 (32GB) or even a later A100 (40-80GB) holds. Megatron-LM from NVIDIA proposes the alternative: split each individual layer’s matrix multiplications across GPUs, with carefully placed communication primitives so the math still works.
The core idea
The analogy: imagine multiplying two matrices A x B, where B is too big to hold in any single computer’s memory. The naive solution is to give every computer the whole problem and split the workload (data parallelism), but that requires every computer to hold all of B. The Megatron solution: cut B into vertical slices (blocks of columns) and give each slice to a different computer. Each computer multiplies A by its slice independently. Placed side by side, the results equal the full A x B. No machine ever held the whole B.
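Here is a minimal pure-Python sketch of the analogy, with toy matrices; all names are illustrative. B is cut into two column blocks, and each "computer" multiplies A by its block alone. Gluing the partial results side by side recovers A x B.

```python
def matmul(a, b):
    """Matrix product of lists-of-rows: a is m x k, b is k x n."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_columns(b, n_parts):
    """Cut b into n_parts vertical slices of equal width (width must divide evenly)."""
    width = len(b[0]) // n_parts
    return [[row[p * width:(p + 1) * width] for row in b] for p in range(n_parts)]

def concat_columns(blocks):
    """Glue column blocks back together side by side, row by row."""
    return [[x for block in blocks for x in block[r]] for r in range(len(blocks[0]))]

A = [[1, 2], [3, 4]]                 # small, replicated on every computer
B = [[1, 0, 2, -1], [0, 3, 1, 5]]    # "too big": no computer holds all of it

slices = split_columns(B, 2)               # computer p stores slices[p] only
partials = [matmul(A, s) for s in slices]  # independent work, no communication
result = concat_columns(partials)
print(result)  # [[1, 6, 4, 9], [3, 12, 10, 17]]
```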
Apply this to a transformer’s two big tensors:
- MLP block. The MLP is Y = GELU(X @ W1) @ W2. The first weight matrix W1 has shape [hidden, 4*hidden]. Megatron splits W1 column-wise across N GPUs, so each GPU stores [hidden, 4*hidden / N]. Each GPU computes its slice of the GELU output independently. W2 has shape [4*hidden, hidden], and Megatron splits it row-wise across the same N GPUs. The matmul (intermediate_slice) @ W2_slice gives a partial sum on each GPU. An all-reduce produces the final output Y.
- Attention block. Each head has its own Q, K, V projections. Megatron splits the heads across GPUs: GPU 0 holds heads 0..H/N-1, GPU 1 holds heads H/N..2H/N-1, etc. Each GPU computes its assigned heads’ attention independently. The Q, K, V projections are therefore split column-wise (by head) and the output projection row-wise. This is the same column-then-row pattern as the MLP, so the block ends with one all-reduce.
“We present our techniques for training very large transformer models and implement a simple, efficient intra-layer model parallel approach that enables training transformer models with billions of parameters.”
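The total communication per layer is two all-reduces in the forward pass (one for the attention output, one for the MLP output), plus two in the backward pass. Memory for the partitioned weights, and for the activations inside each block, is divided by N per GPU. The layer norms and the activations at block boundaries stay replicated.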
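To make the split concrete, here is a hypothetical helper, not part of Megatron-LM, that reports one rank's weight shapes and head assignment. It assumes Q, K, V are fused into a single [hidden, 3*hidden] projection and that the head count divides evenly by N. The example values (hidden=3072, 32 heads) are our assumption about the paper's 8.3B configuration.

```python
def tp_shard_shapes(hidden, num_heads, tp_size, rank):
    """Per-rank weight shapes for one transformer layer under column/row splits.

    Assumes hidden % num_heads == 0, num_heads % tp_size == 0, and a fused
    QKV projection of shape [hidden, 3*hidden].
    """
    assert hidden % num_heads == 0 and num_heads % tp_size == 0
    heads_per_rank = num_heads // tp_size
    local = heads_per_rank * (hidden // num_heads)   # == hidden // tp_size
    first_head = rank * heads_per_rank
    return {
        "heads": list(range(first_head, first_head + heads_per_rank)),
        "qkv": (hidden, 3 * local),              # column split: this rank's heads
        "attn_out": (local, hidden),             # row split: partial sum, then all-reduce
        "w1": (hidden, 4 * hidden // tp_size),   # column split
        "w2": (4 * hidden // tp_size, hidden),   # row split: partial sum, then all-reduce
    }

shapes = tp_shard_shapes(hidden=3072, num_heads=32, tp_size=8, rank=1)
print(shapes)
# {'heads': [4, 5, 6, 7], 'qkv': (3072, 1152), 'attn_out': (384, 3072),
#  'w1': (3072, 1536), 'w2': (1536, 3072)}
```

Every matrix shrinks by exactly N along its split dimension. That is where the 1/N weight memory comes from.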
What’s clever — find the instinct
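The non-obvious move is that the partition direction matters. Suppose you split W1 row-wise instead of column-wise. Each GPU then computes X_i @ W1_i, which is only a partial sum of the full [B, seq, 4*hidden] pre-activation. Because GELU is nonlinear, GELU(X_1 @ W1_1 + X_2 @ W1_2) ≠ GELU(X_1 @ W1_1) + GELU(X_2 @ W1_2). You therefore need an all-reduce before the GELU: an extra synchronization point, on a tensor four times wider than the layer output.

Splitting W1 column-wise avoids this. Each GPU multiplies the replicated X by its own column slice and produces a column slice of the intermediate. W2 is then naturally split row-wise to consume those columns. The only required communication is at the very end of the MLP: an all-reduce on the partial sums.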
This is what the paper calls the “f and g operators”, a conjugate pair placed at each parallelized block’s input and output respectively. f is the identity in the forward pass and an all-reduce in the backward pass. g is the reverse: an all-reduce in the forward pass and the identity in the backward pass. The mathematical identity that makes it work:
If W1 = [W1_a | W1_b] (column split) and W2 = [W2_a; W2_b] (row split, stacked):
X @ W1 = [X @ W1_a | X @ W1_b]
GELU(X @ W1) = [GELU(X @ W1_a) | GELU(X @ W1_b)] (GELU is element-wise)
GELU(X @ W1) @ W2 = GELU(X @ W1_a) @ W2_a + GELU(X @ W1_b) @ W2_b (sum of partials)
So:
Each GPU computes (GELU(X @ W1_slice) @ W2_slice) → a partial sum
All-reduce sums them → the full result
No intermediate communication.
Element-wise nonlinearities like GELU are the key enabler. They commute with the column split. If GELU were replaced by a softmax across the hidden dimension, this would not work; you would need a cross-GPU softmax, which is much more expensive.
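The identity can be checked numerically. This pure-Python sketch uses toy sizes (hidden=2, N=2), made-up weights, and the exact erf-based GELU. It confirms two things: summing the per-shard partials (the one all-reduce) reproduces the unsharded MLP, and a row split of W1 does not survive the GELU.

```python
import math

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def gelu(m):
    """Exact (erf-based) GELU, applied element-wise."""
    return [[0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in row] for row in m]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def close(a, b, tol=1e-9):
    return all(abs(x - y) < tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))

hidden, n = 2, 2                                   # toy sizes: W1 is 2x8, W2 is 8x2
X = [[0.3, -1.2], [2.0, 0.7]]
W1 = [[0.1 * (i + j) - 0.4 for j in range(4 * hidden)] for i in range(hidden)]
W2 = [[0.05 * (i - j) for j in range(hidden)] for i in range(4 * hidden)]

full = matmul(gelu(matmul(X, W1)), W2)             # unsharded reference

k = 4 * hidden // n                                # columns of W1 per shard
tp_out = [[0.0] * hidden for _ in X]
for p in range(n):
    W1_p = [row[p * k:(p + 1) * k] for row in W1]  # column slice of W1
    W2_p = W2[p * k:(p + 1) * k]                   # matching row slice of W2
    tp_out = add(tp_out, matmul(gelu(matmul(X, W1_p)), W2_p))  # summing = the all-reduce
print(close(tp_out, full))                         # True

# Row-splitting W1 instead: each shard sees one column of X and one row of W1,
# and GELU of the partial sums is not GELU of the full pre-activation.
X0, X1 = [[r[0]] for r in X], [[r[1]] for r in X]
wrong = add(gelu(matmul(X0, [W1[0]])), gelu(matmul(X1, [W1[1]])))
print(close(wrong, gelu(matmul(X, W1))))           # False
```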
“Our approach does not require a new compiler or library changes, is orthogonal and complimentary to pipeline model parallelism, and can be fully implemented with the insertion of a few communication operations in native PyTorch.”
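The quote's “few communication operations” are the f and g operators. The toy simulation below shows their forward and backward behavior on a 4-way split of y = x * (w_0 + w_1 + w_2 + w_3), a scalar stand-in for a sharded layer whose partial outputs are summed. Plain Python lists stand in for per-GPU copies, and the function names are ours, not Megatron-LM's.

```python
def all_reduce(values):
    """Sum across simulated ranks; every rank receives the total."""
    total = sum(values)
    return [total] * len(values)

def f_forward(x, n):       # identity: each rank gets the same replicated input
    return [x] * n

def f_backward(grads):     # all-reduce: input grad sums every rank's contribution
    return all_reduce(grads)

def g_forward(partials):   # all-reduce: combine partial outputs
    return all_reduce(partials)

def g_backward(grad, n):   # identity: each rank gets the same output grad
    return [grad] * n

w = [0.5, -1.0, 2.0, 3.0]  # rank i holds only w[i]
x = 1.5

xs = f_forward(x, len(w))
y = g_forward([xi * wi for xi, wi in zip(xs, w)])       # y[i] is the full output on rank i

dy = 1.0                                                # upstream gradient dL/dy
dys = g_backward(dy, len(w))
dx = f_backward([dyi * wi for dyi, wi in zip(dys, w)])  # local dL/dx per rank, then summed

print(y[0], dx[0])  # 6.75 4.5
```

Both results match the unsharded math: y = x * sum(w) and dL/dx = sum(w).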
The second clever move is layer norm placement. The original transformer, and BERT, used post-norm, applying layer norm after the residual addition. Megatron found that with this ordering, larger BERT-style models trained worse instead of better. They rearranged the layer normalization and residual connections into essentially the pre-norm ordering GPT-2 already used, with layer norm at the input of each sub-layer, inside the residual branch. This let them scale BERT to 3.9B parameters with improving accuracy. Pre-norm is now the common default: GPT-J, LLaMA, and Mistral all use it.
“We show that careful attention to the placement of layer normalization in BERT-like models is critical to achieving increased performance as the model size grows.”
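The two orderings differ only in where the normalization sits relative to the residual addition. In the pure-Python sketch below, LayerNorm has no learned scale or shift, and `toy_sublayer` is a hypothetical stand-in for attention or the MLP. When the sub-layer outputs zero, the pre-norm block is exactly the identity, while the post-norm block still renormalizes the residual stream. That clean identity path is one common intuition for why pre-norm trains more stably in deep stacks.

```python
import math

def layer_norm(v, eps=1e-5):
    """LayerNorm over one vector; learned gain and bias assumed to be 1 and 0."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]

def post_norm_block(x, sublayer):
    # Original Transformer / BERT: add the residual, THEN normalize.
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_norm_block(x, sublayer):
    # Pre-norm: normalize the sub-layer's input; the residual path is left untouched.
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

def toy_sublayer(v):
    return [0.5 * t for t in v]   # hypothetical stand-in for attention or the MLP

def zero_sublayer(v):
    return [0.0] * len(v)

x = [1.0, 2.0, 3.0, 4.0]
print(pre_norm_block(x, zero_sublayer))                    # [1.0, 2.0, 3.0, 4.0]
print(post_norm_block(x, zero_sublayer) == layer_norm(x))  # True
print(pre_norm_block(x, toy_sublayer))
```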
The third clever move: vocabulary parallelism. The output projection’s weight is [hidden, vocab_size]. Vocab is often 32K-256K, large enough that this projection alone is a substantial chunk of memory. Megatron splits the vocab dimension across GPUs and has each GPU compute the logits for its slice; the input embedding, which shares this weight, is split the same way. A parallel cross-entropy loss then exchanges only per-token scalars, such as the target logit and the pieces of the log-sum-exp normalizer. It never all-gathers the full [B, seq, vocab] logits.
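Here is a minimal sketch for a single token, in plain Python with two simulated vocab shards; the helper names are hypothetical. Python's `max()` and `sum()` over the shard list stand in for the cross-rank all-reduces. Three quantities cross ranks: the global max (for numerical stability), the target logit, and the global sum of exponentials. The result matches an unsharded cross-entropy.

```python
import math

def vocab_parallel_cross_entropy(logit_shards, target):
    """Cross-entropy for one token when contiguous vocab slices live on different ranks."""
    # 1) all-reduce(MAX) of each rank's local max, for numerical stability
    global_max = max(max(shard) for shard in logit_shards)

    # 2) all-reduce(SUM) of the target logit: only the rank that owns it contributes
    target_logit, start = 0.0, 0
    for shard in logit_shards:
        if start <= target < start + len(shard):
            target_logit += shard[target - start]
        start += len(shard)

    # 3) all-reduce(SUM) of each rank's local sum of exp(logit - max)
    sum_exp = sum(sum(math.exp(z - global_max) for z in shard) for shard in logit_shards)

    # loss = logsumexp(logits) - logit[target]
    return math.log(sum_exp) + global_max - target_logit

def full_cross_entropy(logits, target):
    m = max(logits)
    return math.log(sum(math.exp(z - m) for z in logits)) + m - logits[target]

logits = [2.0, -1.0, 0.5, 3.5, 1.0, -0.3, 0.0, 2.2]   # toy vocab of 8
shards = [logits[0:4], logits[4:8]]                     # 2-way vocab split
loss_tp = vocab_parallel_cross_entropy(shards, target=5)
loss_ref = full_cross_entropy(logits, target=5)
print(abs(loss_tp - loss_ref) < 1e-9)  # True
```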
Walkthrough: 8.3B model on 8 V100s
Setup: GPT-2-style 8.3B parameter transformer.
V100 32GB GPUs.
8 GPUs in one tensor-parallel group.
Memory per GPU (no parallelism):
Parameters (fp16): 8.3B * 2 = 16.6 GB
Gradients (fp16): 16.6 GB
Optimizer (fp32, Adam): 8.3B * 12 = 99.6 GB
Activations (peak): ~10 GB
TOTAL: ~143 GB
V100 has 32 GB → out of memory.
With Megatron 8-way tensor parallel:
Parameters per GPU: 16.6 / 8 = 2.1 GB
Gradients per GPU: 2.1 GB
Optimizer per GPU: 99.6 / 8 = 12.5 GB
Activations per GPU: ~10 GB / 8 = 1.25 GB (optimistic: only the partitioned activations shrink; replicated ones do not)
TOTAL per GPU: ~18 GB
Fits in V100.
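The arithmetic above fits in a small calculator. It is a sketch that follows the walkthrough's own accounting: 2 + 2 + 12 bytes per parameter, 1 GB = 1e9 bytes, and activations optimistically divided by the tensor-parallel degree. The function name is hypothetical.

```python
def memory_per_gpu_gb(n_params, tp_size, activations_gb):
    """Per-GPU memory in GB for mixed-precision Adam, split tp_size ways.

    Per parameter: fp16 weights (2 B) + fp16 grads (2 B) + fp32 Adam state (12 B:
    master weights, momentum, variance). Activations are also divided by
    tp_size, which is optimistic.
    """
    per_param_bytes = {"params": 2, "grads": 2, "optimizer": 12}
    out = {k: n_params * b / 1e9 / tp_size for k, b in per_param_bytes.items()}
    out["activations"] = activations_gb / tp_size
    out["total"] = sum(out.values())
    return out

single = memory_per_gpu_gb(8.3e9, tp_size=1, activations_gb=10)
tp8 = memory_per_gpu_gb(8.3e9, tp_size=8, activations_gb=10)
print(round(single["total"], 1), round(tp8["total"], 2))  # 142.8 17.85
```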
Compute pattern per layer (forward, MLP block):
Each GPU has X (full input, replicated): shape [B, seq, hidden].
Each GPU has W1_slice: [hidden, 4*hidden / 8].
GPU 0 computes: GELU(X @ W1_slice_0) → [B, seq, 4*hidden/8]
GPU 1 computes: GELU(X @ W1_slice_1) → [B, seq, 4*hidden/8]
...
Each GPU has W2_slice: [4*hidden / 8, hidden].
GPU 0 computes: intermediate_0 @ W2_slice_0 → [B, seq, hidden] (partial)
GPU 1 computes: intermediate_1 @ W2_slice_1 → [B, seq, hidden] (partial)
...
All-reduce sums all 8 partials → final MLP output.
Communication per layer:
Forward: 1 all-reduce (after attention) + 1 all-reduce (after MLP) = 2 all-reduces.
Backward: same pattern reversed = 2 all-reduces.
Total per layer per step: 4 all-reduces of activation-size tensors.
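To put “activation-size” in bytes, here are two hypothetical helpers. They assume fp16 activations, a micro-batch of 8 sequences, seq=1024, and hidden=3072; the last two are our assumption about the 8.3B configuration. The ring all-reduce model, where each GPU sends 2(N-1)/N of the tensor, is the standard bandwidth-optimal estimate, not a measurement.

```python
def activation_bytes(batch, seq, hidden, bytes_per_elem=2):
    """Size of one [batch, seq, hidden] activation tensor."""
    return batch * seq * hidden * bytes_per_elem

def ring_allreduce_bytes_sent(tensor_bytes, group_size):
    """Bytes each GPU sends in a ring all-reduce (reduce-scatter + all-gather)."""
    return 2 * (group_size - 1) / group_size * tensor_bytes

one = activation_bytes(batch=8, seq=1024, hidden=3072)   # ~50 MB per all-reduce
per_layer = 4 * one                                      # 4 all-reduces per layer per step
sent = 4 * ring_allreduce_bytes_sent(one, group_size=8)  # per GPU, per layer per step
print(one, per_layer, sent)  # 50331648 201326592 352321536.0
```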
Achieved: 15.1 PetaFLOPs across 512 V100s; 76% scaling efficiency.
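The 76% scaling efficiency is measured against a strong single-GPU baseline that sustains 39 TeraFLOPs (30% of peak), in a weak-scaling setup where the model grows with the GPU count. That efficiency is what makes Megatron a practical approach: communication overhead does not catastrophically erode the gains.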
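The efficiency figure is just achieved throughput over perfect linear scaling of the single-GPU baseline. A sketch of that arithmetic, using only the numbers quoted from the paper:

```python
single_gpu_tflops = 39.0   # single-GPU baseline quoted in the paper's abstract
num_gpus = 512
achieved_pflops = 15.1

ideal_pflops = single_gpu_tflops * num_gpus / 1000   # perfect scaling: 19.968 PFLOPs
efficiency = achieved_pflops / ideal_pflops
print(round(efficiency, 3))  # 0.756
```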
Does it work? What breaks?
| Model | Parameters | Hardware | Throughput / outcome |
|---|---|---|---|
| GPT-2 8.3B | 8.3B | 512 V100 | 15.1 PFLOPs (76% scaling) |
| BERT 3.9B | 3.9B | 512 V100 | trained successfully |
State-of-the-art achieved at the time:
- WikiText-103 perplexity: 10.8 (vs 15.8 prior).
- LAMBADA accuracy: 66.5% (vs 63.2%).
- RACE (BERT 3.9B): 90.9% (vs 89.4%).
“We illustrate this approach by converging transformer based models up to 8.3 billion parameters using 512 GPUs. We sustain 15.1 PetaFLOPs across the entire application with 76% scaling efficiency.”
What breaks:
- Tensor parallel scales only within a fast-interconnect domain. All-reduces happen every layer; they need NVLink (300 GB/s). Across nodes (InfiniBand, ~25 GB/s), tensor parallel becomes a bottleneck. In practice, tensor parallel is used at degree 4 or 8 (within one node), then composed with pipeline parallel and data parallel for cross-node scaling.
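- Activations at block boundaries are still replicated. Tensor parallelism partitions the activations inside attention and the MLP, but layer norm, dropout, and residual activations are duplicated on every GPU. Sequence parallelism (a follow-up) splits those along the sequence axis to save more memory.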
- Pipeline parallel is needed for very large models. Use tensor parallel up to degree 8 within a node, then split the model layer-wise across nodes (pipeline). The paper already notes that its approach is complementary to pipeline parallelism. The full 3D parallelism of (data, tensor, pipeline) became the standard recipe in later work built on this one.
- Implementation complexity. Megatron-LM is its own training codebase on top of PyTorch. Model layers must be rewritten as tensor-parallel layers, and gradient accumulation, mixed precision, and checkpointing all have to be made aware of the partitioning. Modern frameworks (Megatron-DeepSpeed, Colossal-AI, NeMo) abstract this, but the underlying complexity remains.
- Vocabulary parallel is non-trivial. Cross-entropy loss with a partitioned softmax requires careful handling of the log-sum-exp normalization.
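A rough, bandwidth-only estimate shows why the interconnect point dominates. The sketch below times one ring all-reduce of an fp16 [batch, seq, hidden] activation at the two bandwidths quoted above. It assumes batch=8, seq=1024, and hidden=3072, and ignores latency, topology, and overlap with compute, so treat the numbers as optimistic lower bounds.

```python
def ring_allreduce_seconds(tensor_bytes, group_size, link_gb_per_s):
    """Bandwidth-only estimate: ignores latency, topology, and overlap with compute."""
    bytes_sent = 2 * (group_size - 1) / group_size * tensor_bytes
    return bytes_sent / (link_gb_per_s * 1e9)

tensor = 8 * 1024 * 3072 * 2              # one fp16 [batch, seq, hidden] tensor (assumed sizes)
nvlink = ring_allreduce_seconds(tensor, 8, link_gb_per_s=300)
infiniband = ring_allreduce_seconds(tensor, 8, link_gb_per_s=25)
print(f"{nvlink * 1e3:.2f} ms vs {infiniband * 1e3:.2f} ms")  # 0.29 ms vs 3.52 ms
```

The transfer is the same size either way, but the 12x bandwidth gap is paid on every one of the four all-reduces per layer. That makes cross-node tensor parallelism communication-bound.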
So what?
For a practitioner training models past the single-GPU memory limit:
- Tensor parallelism is the right first axis to add. Within a node (4 or 8 GPUs with NVLink), tensor parallelism gives near-linear memory scaling at modest communication cost.
- Combine with ZeRO or FSDP for the data axis. Tensor parallelism divides memory by at most the tensor-parallel degree (4 or 8 within a node). If the optimizer states still do not fit, shard them across the data-parallel replicas with ZeRO or FSDP. ZeRO-1 + tensor parallel is a standard combination.
- Pre-norm is the safe default at scale. Megatron showed its importance for BERT-style models, and most modern large transformers follow. If a post-norm model at 1B+ parameters trains unstably or gets worse as you scale it, layer-norm placement is a likely culprit.
- For models past ~50B, add pipeline parallel. Tensor parallel beyond degree 8 is communication-bound; pipeline parallel across nodes lets you scale horizontally. The 3D recipe (Megatron-DeepSpeed, Colossal-AI) is the production standard.
- For deployment, use TensorRT-LLM or similar. Megatron’s training-side tensor parallelism informed the inference-side analogues: vLLM and TensorRT-LLM both support tensor-parallel inference for serving models that span multiple GPUs.
For the L5 systems interview question “explain how to train a 100B model”: Megatron’s tensor parallelism is the answer to “how do individual layers get parallelized.” Then layer in pipeline parallel for cross-node scaling, then layer in ZeRO for the data axis. The triple is the canonical answer.
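To compose the three axes, every GPU needs a coordinate on each. Below is a minimal sketch of one common convention, using hypothetical helpers; real frameworks may order the axes differently. The tensor-parallel index varies fastest, so each tensor-parallel group occupies consecutive ranks. If ranks are numbered node by node, each group then stays inside one NVLink domain. The 512-GPU layout of 8-way tensor x 4-way pipeline x 16-way data parallel is illustrative.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp_idx, pp_idx, tp_idx), tensor-parallel index fastest."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx

def tensor_parallel_group(rank, tp):
    """All ranks sharing this rank's layer shards (same data and pipeline index)."""
    start = rank - rank % tp
    return list(range(start, start + tp))

# Hypothetical 512-GPU job: 8-way tensor x 4-way pipeline x 16-way data parallel.
print(rank_to_coords(37, tp=8, pp=4, dp=16))   # (1, 0, 5)
print(tensor_parallel_group(37, tp=8))          # [32, 33, 34, 35, 36, 37, 38, 39]
```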
Connections
- model-parallel — Megatron is the seminal tensor-parallel paper
- tensor-parallel — the intra-layer scheme introduced here is what “tensor parallelism” usually refers to today
- distributed-training — required reading
- memory-efficiency — primary motivation
- zero-memory-optimizations-trillion-parameter-models — composes orthogonally; ZeRO handles data axis, Megatron handles tensor axis
- attention-is-all-you-need — the architecture being parallelized
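- microsoft-research — not involved in this paper (Megatron-LM is NVIDIA work); Microsoft and NVIDIA later collaborated on Megatron-Turing NLG, trained with Megatron-DeepSpeed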
Citation
Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J., & Catanzaro, B. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv preprint. https://arxiv.org/abs/1909.08053