The Problem
Modern frontier models do not fit on a single GPU. The fp16 weights of a 70B-parameter model alone take 140 GB. Add gradients and Adam optimizer states (about 16 bytes per parameter in mixed precision) and the total passes 1 TB before counting activations. No single device has that much memory. Even where memory allows (a 1B model needs roughly 16 GB of training state), training on one GPU can take weeks; you want to use 8 or 64 or 1024 GPUs to compress that to days.
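The 140 GB figure can be extended to a rough estimate of total training state. The sketch below assumes mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer state (master weights, momentum, variance). The function name and the byte counts are illustrative assumptions, and activations are left out because they depend on batch size and sequence length.

```python
def training_state_gb(n_params, weight_bytes=2, grad_bytes=2, optim_bytes=12):
    """Rough model-state memory in GB, excluding activations (assumed byte counts)."""
    gb = 1e9
    return {
        "weights": n_params * weight_bytes / gb,
        "gradients": n_params * grad_bytes / gb,
        "optimizer": n_params * optim_bytes / gb,
        "total": n_params * (weight_bytes + grad_bytes + optim_bytes) / gb,
    }

est = training_state_gb(70e9)
for name, size in est.items():
    print(f"{name:>10}: {size:8.1f} GB")
```

Under these assumptions the model state alone is several times larger than the weights, which is why the weights-only number understates the problem.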
You need to split the work across devices in a way that:
- Lets each device fit what it owns in its memory.
- Keeps the per-device compute high (avoid idle time).
- Minimizes cross-device communication, a cost that single-device training never pays.
The Key Insight
There are three orthogonal axes you can split along, and the right system combines them:
- Data parallelism. Each device has the full model; different devices process different batches; gradients are aggregated. Easy and effective until the model itself does not fit.
- Model parallelism. The model is split across devices. Each device holds a slice of every layer (tensor parallel) or a subset of layers (pipeline parallel). Required when the model does not fit on one device.
- State sharding. Even with data parallel, each device replicating the optimizer states is wasteful. Shard the states across the data-parallel group; each parameter has one canonical owner. (ZeRO / FSDP.)
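The first two axes can be checked numerically on a single machine. The sketch below simulates devices as NumPy array slices; it is a toy illustration under assumed shapes, not a distributed implementation. The data-parallel part shows that averaging per-shard gradients reproduces the full-batch gradient when shards are equal in size. The tensor-parallel part uses a column split followed by a row split, in the style described for Megatron-LM, so that a single sum (standing in for an all-reduce) recovers the unsplit result.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))   # 8 examples, 4 features
y = rng.normal(size=(8,))
w = rng.normal(size=(4,))

def mse_grad(Xb, yb, w):
    """Gradient of mean((Xb @ w - yb) ** 2) with respect to w."""
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)

# Data parallel: 4 simulated devices, each with 2 examples and the full w.
shards = list(zip(np.split(X, 4), np.split(y, 4)))
local_grads = [mse_grad(Xb, yb, w) for Xb, yb in shards]
dp_grad = np.mean(local_grads, axis=0)   # stand-in for an averaging all-reduce
full_grad = mse_grad(X, y, w)

# Tensor parallel: a two-layer MLP split across 2 simulated devices.
A = rng.normal(size=(4, 6))   # first weight, split by columns
B = rng.normal(size=(6, 4))   # second weight, split by rows
A_parts = np.split(A, 2, axis=1)
B_parts = np.split(B, 2, axis=0)
# Each device applies ReLU to its own hidden slice, so no communication is
# needed between the two matmuls; only the partial outputs are summed.
partials = [np.maximum(X @ Ak, 0.0) @ Bk for Ak, Bk in zip(A_parts, B_parts)]
tp_out = np.sum(partials, axis=0)        # stand-in for a sum all-reduce
full_out = np.maximum(X @ A, 0.0) @ B

print("DP matches full batch:", np.allclose(dp_grad, full_grad))
print("TP matches unsplit MLP:", np.allclose(tp_out, full_out))
```

The equal-shard assumption matters for the data-parallel check: with uneven shards, a plain mean of local gradients would weight examples differently from the full-batch gradient.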
Mechanism in Plain English
The taxonomy:
- Data parallel (DP/DDP). N devices, each with a full model copy, each processing batch_size/N examples. After backward, all-reduce gradients across all devices. Synchronize. Apply optimizer step locally on each device. Used when model fits.
- Tensor parallel (TP). Within a single layer’s matmul, split the weight matrix across devices. Each device computes a slice of the output. Communication: all-reduces on activations inside every layer, in both the forward and backward pass. Used when individual layers are too big.
- Pipeline parallel (PP). Split layers across devices: GPU 0 holds layers 1-8, GPU 1 holds layers 9-16, etc. Pipeline micro-batches through. Communication: activations passed between adjacent stages. Bubbles in the schedule waste compute. Used when full model does not fit even with TP.
- Sharded data parallel (ZeRO / FSDP). Like DP but with optimizer states (and optionally gradients and parameters) partitioned across the DP group. Memory savings up to N-fold. Communication: similar to DP plus parameter all-gathers (in ZeRO-3 / FULL_SHARD).
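To make the "up to N-fold" saving concrete, here is a minimal sketch of per-device model-state memory at each sharding level. It assumes the same accounting as mixed-precision Adam (2 bytes of weights, 2 of gradients, 12 of optimizer state per parameter) and uses stage numbers in the ZeRO sense: 0 is plain DP, 1 shards optimizer states, 2 also shards gradients, 3 also shards parameters. The function name is hypothetical, and activations and communication buffers are ignored.

```python
def zero_state_gb(n_params, dp_degree, stage):
    """Per-device model-state memory in GB for a given sharding stage (0-3)."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    weight_bytes, grad_bytes, optim_bytes = 2.0, 2.0, 12.0
    if stage >= 1:
        optim_bytes /= dp_degree   # optimizer states partitioned
    if stage >= 2:
        grad_bytes /= dp_degree    # gradients partitioned as well
    if stage >= 3:
        weight_bytes /= dp_degree  # parameters partitioned as well
    return n_params * (weight_bytes + grad_bytes + optim_bytes) / 1e9

for stage in range(4):
    print(f"stage {stage}: {zero_state_gb(70e9, dp_degree=64, stage=stage):8.1f} GB per device")
```

Only stage 3 reaches the full N-fold reduction; stages 1 and 2 keep a replicated copy of the fp16 weights (and, for stage 1, the gradients) on every device.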
3D parallelism combines all three: total GPUs = DP * TP * PP. For a frontier model on 1024 GPUs, you might have TP=8 (within a node), PP=8 (across 8 nodes), and DP=16 (16 such 64-GPU replicas). 8 * 8 * 16 = 1024.
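One way to picture the 8 * 8 * 16 layout is as a mapping from a flat global rank to a (DP, PP, TP) coordinate. The sketch below assumes one possible ordering: TP varies fastest, so each TP group is 8 consecutive ranks on one node, then DP, then PP. Real frameworks choose their own orderings, and the names `Coord` and `rank_to_coord` are hypothetical.

```python
from typing import NamedTuple

class Coord(NamedTuple):
    dp: int  # which data-parallel replica
    pp: int  # which pipeline stage
    tp: int  # which tensor-parallel slice within the node

def rank_to_coord(rank, tp, pp, dp):
    """Map a global rank to (dp, pp, tp), assuming TP fastest, then DP, then PP."""
    world = tp * pp * dp
    if not 0 <= rank < world:
        raise ValueError(f"rank {rank} outside world size {world}")
    tp_rank = rank % tp
    dp_rank = (rank // tp) % dp
    pp_rank = rank // (tp * dp)
    return Coord(dp=dp_rank, pp=pp_rank, tp=tp_rank)

for r in (0, 7, 8, 128, 1023):
    print(r, rank_to_coord(r, tp=8, pp=8, dp=16))
```

With this ordering, ranks 0-7 form the TP group of replica 0 at stage 0, ranks 8-15 belong to replica 1 at the same stage, and the first rank of stage 1 is 8 * 16 = 128.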
ASCII Diagram
3D Parallelism: TP within node, PP across nodes, DP across replicas
(each box is one 8-GPU node)

            DP rank 0                      DP rank 1            ...
  +---------------------------+  +---------------------------+
  | PP stage 0 (layers 1-8)   |  | PP stage 0 (layers 1-8)   |
  | TP within node            |  | TP within node            |
  | GPUs 0-7                  |  | GPUs 8-15                 |
  +---------------------------+  +---------------------------+
                |                              |
                v                              v
  +---------------------------+  +---------------------------+
  | PP stage 1 (layers 9-16)  |  | PP stage 1 (layers 9-16)  |
  | TP within node            |  | TP within node            |
  | GPUs 16-23                |  | GPUs 24-31                |
  +---------------------------+  +---------------------------+
                |                              |
                v                              v
               ...                            ...
Tensor parallel uses the fast interconnect (NVLink) within a node. Pipeline parallel passes activations across the slower inter-node interconnect (InfiniBand). Data parallel synchronizes gradients via all-reduce, which can also run across nodes.
What’s Clever
The reason 3D parallelism works is that each axis serves a different bottleneck:
- DP / FSDP scales the data axis (more examples per step) but is communication-bound by the gradient all-reduce.
- TP scales the layer-width axis (bigger matrices) but communicates inside every layer, so it only scales well within a fast interconnect.
- PP scales the depth axis (more layers) but pipeline bubbles waste compute; the idle fraction grows with PP degree and shrinks as more micro-batches are in flight.
Combining them lets you use each axis where it shines: TP within fast NVLink, PP across slower IB, DP for the outermost replication.
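The pipeline-bubble cost can be estimated with a standard back-of-the-envelope formula. The sketch below assumes a non-interleaved GPipe- or 1F1B-style schedule in which every micro-batch takes the same time on every stage; under that assumption the idle fraction is (p - 1) / (m + p - 1) for p stages and m micro-batches. The function name is illustrative, and interleaved schedules would give smaller values.

```python
def bubble_fraction(pp_stages, micro_batches):
    """Idle fraction of a simple pipeline schedule: (p - 1) / (m + p - 1)."""
    if pp_stages < 1 or micro_batches < 1:
        raise ValueError("pp_stages and micro_batches must be >= 1")
    return (pp_stages - 1) / (micro_batches + pp_stages - 1)

for m in (8, 32, 128):
    print(f"PP=8, micro-batches={m:3d}: bubble = {bubble_fraction(8, m):.1%}")
```

The formula shows why a larger PP degree needs proportionally more micro-batches per step to keep devices busy.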
Concrete Walkthrough
Train a 175B GPT-3 model on a 1024-GPU cluster, assuming 128 nodes of 8 GPUs each:
Per-GPU memory budget: 80 GB (A100).
3D parallelism choice:
TP = 8 (within node)
PP = 16 (across nodes)
DP = 1024 / (8 * 16) = 8
Per GPU:
Parameter memory: 175e9 * 2 bytes / (TP * PP) = 350 GB / 128 = 2.7 GB
Gradient: 2.7 GB
Optimizer (Adam, fp32): 175e9 * 12 / 128 = 16.4 GB
Activations (per pipeline stage): ~4 GB
All-gather buffers: ~2 GB
TOTAL: ~28 GB. Fits in 80 GB with room for batch size scaling.
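The per-GPU budget above can be reproduced in a few lines. This sketch uses the walkthrough's own numbers; the activation and buffer figures are the rough estimates quoted above rather than computed values, and the optimizer state is divided only by TP * PP (no sharding across the DP group).

```python
params = 175e9
tp, pp = 8, 16
model_shards = tp * pp                        # 128 GPUs hold one model replica

param_gb = params * 2 / model_shards / 1e9    # fp16 weights
grad_gb = params * 2 / model_shards / 1e9     # fp16 gradients
optim_gb = params * 12 / model_shards / 1e9   # fp32 Adam state
activation_gb = 4.0                           # rough estimate from the walkthrough
buffer_gb = 2.0                               # rough estimate from the walkthrough

total_gb = param_gb + grad_gb + optim_gb + activation_gb + buffer_gb
print(f"params {param_gb:.1f} GB, grads {grad_gb:.1f} GB, optimizer {optim_gb:.1f} GB")
print(f"total  {total_gb:.1f} GB of an 80 GB budget")
```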
Communication patterns:
TP all-reduce: ~ activation_size, every layer, NVLink (~300 GB/s).
PP send/recv: ~ activation_size, between adjacent stages, IB (~25 GB/s).
DP all-reduce: ~ gradient_size, once per step, IB (~25 GB/s).
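A rough time for the DP gradient all-reduce follows from the bandwidth term of a ring all-reduce, 2 * (N - 1) / N * size / bandwidth. The sketch below applies it to the figures above (about 2.7 GB of gradients per GPU, DP = 8, about 25 GB/s). It ignores latency, overlap with the backward pass, and topology effects, so treat the result as an order-of-magnitude estimate; the function name is hypothetical.

```python
def ring_allreduce_seconds(size_gb, n_ranks, bw_gb_per_s):
    """Bandwidth-only estimate of a ring all-reduce over n_ranks devices."""
    if n_ranks < 1 or bw_gb_per_s <= 0:
        raise ValueError("need n_ranks >= 1 and positive bandwidth")
    if n_ranks == 1:
        return 0.0  # nothing to exchange
    return 2 * (n_ranks - 1) / n_ranks * size_gb / bw_gb_per_s

dp_seconds = ring_allreduce_seconds(size_gb=2.7, n_ranks=8, bw_gb_per_s=25.0)
print(f"DP gradient all-reduce: ~{dp_seconds:.2f} s per step")
```

Because this cost is paid once per optimizer step, gradient accumulation spreads it over more compute.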
Effective batch size: micro-batch per model replica * DP = 4 * 8 = 32 sequences (small).
With gradient accumulation 8x: effective batch of 256 sequences.
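The batch arithmetic generalizes to a single product. In the sketch below the sequence length of 2048 is an assumption added only to convert sequences into tokens; the function name is illustrative.

```python
def global_batch(micro_batch, dp_degree, accum_steps=1):
    """Sequences consumed per optimizer step across all DP replicas."""
    return micro_batch * dp_degree * accum_steps

no_accum = global_batch(micro_batch=4, dp_degree=8)
with_accum = global_batch(micro_batch=4, dp_degree=8, accum_steps=8)

seq_len = 2048  # assumed sequence length, not stated in the walkthrough
print(f"without accumulation: {no_accum} sequences")
print(f"with 8x accumulation: {with_accum} sequences "
      f"({with_accum * seq_len:,} tokens at seq_len={seq_len})")
```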
Key Sources
- zero-memory-optimizations-trillion-parameter-models — ZeRO data-parallel state sharding
- megatron-lm-training-multi-billion-parameter-language-models — tensor parallel
- pytorch-fsdp-fully-sharded-data-parallel — PyTorch-native ZeRO-3
Related Concepts
- data-parallel — the simplest axis
- model-parallel — required when model does not fit
- tensor-parallel — intra-layer model parallelism
- memory-efficiency — the dominant constraint
- mixed-precision-training — orthogonal lever; roughly halves activation memory, while fp32 optimizer state stays the same size
- inference-efficiency — closely related but for serving
Open Questions
- Pipeline bubble reduction. GPipe, PipeDream, 1F1B, Megatron’s interleaved schedule all attempt this. Tradeoffs between memory and bubble are not fully resolved.
- Cross-cluster training. Beyond one cluster (multiple datacenters), latency and limited bandwidth make fully synchronous SGD impractical. Low-communication and asynchronous methods (DiLoCo, federated averaging) trade convergence speed for tolerance to slow comms.
- Heterogeneous hardware. Most distributed training assumes homogeneous nodes. Mixing A100s, H100s, B200s in one job is an open systems problem.