The Problem
You have a model that fits on one GPU and a training loop that takes too long on one GPU. You want to use 8 or 64 GPUs to go, ideally, 8 or 64 times faster. The simplest way to do this is to process different examples on different GPUs in parallel.
The Key Insight
Replicate the model on every GPU. Split the batch across GPUs. Each GPU processes its sub-batch independently and computes local gradients. Average the local gradients across GPUs (typically via all-reduce), so that each GPU holds the same global gradient. Each GPU then runs the same optimizer step on its replica. After the step, all replicas are still identical.
This is bulk-synchronous parallelism: all GPUs march in lockstep, and gradient aggregation is the synchronization point.
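To make the aggregation step concrete, here is a minimal pure-Python simulation of a ring all-reduce, one common way collectives libraries implement it. The text above does not say which algorithm is used, so the ring is an assumption, and the function name `ring_all_reduce_mean` is hypothetical. Lists stand in for GPU buffers, and the sketch assumes the gradient length is a multiple of the number of workers.

```python
def ring_all_reduce_mean(worker_grads):
    """Simulate a ring all-reduce that leaves every worker with the mean gradient.

    worker_grads: list of N equal-length lists of floats, one per simulated GPU.
    Assumption for this sketch: the gradient length is a multiple of N.
    """
    n = len(worker_grads)
    length = len(worker_grads[0])
    assert length % n == 0, "sketch assumes gradient length divisible by N"
    size = length // n
    # Split each worker's gradient into N chunks (slices copy, so inputs stay untouched).
    chunks = [[g[c * size:(c + 1) * size] for c in range(n)] for g in worker_grads]

    # Phase 1, reduce-scatter: after N-1 steps, worker r holds the full sum of chunk (r+1) % N.
    for step in range(n - 1):
        # Snapshot all outgoing messages first so the sends in a step are simultaneous.
        sends = [(r, (r - step) % n, list(chunks[r][(r - step) % n])) for r in range(n)]
        for r, c, payload in sends:
            dst = (r + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]

    # Phase 2, all-gather: pass the finished chunks around the ring until everyone has all of them.
    for step in range(n - 1):
        sends = [(r, (r + 1 - step) % n, list(chunks[r][(r + 1 - step) % n])) for r in range(n)]
        for r, c, payload in sends:
            chunks[(r + 1) % n][c] = payload

    # Flatten each worker's chunks and divide by N to turn the sum into a mean.
    return [[x / n for chunk in w for x in chunk] for w in chunks]


grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]
print(ring_all_reduce_mean(grads))  # both workers hold [2.0, 3.0, 4.0, 5.0]
```

Each worker only ever talks to its ring neighbor, and every worker finishes with the same buffer, which is the property the optimizer step relies on.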
Mechanism in Plain English
- Initialize the model. Make N copies, one per GPU.
- Split the global batch into N sub-batches.
- Each GPU runs forward + backward on its sub-batch, producing local gradients.
- All-reduce gradients across all GPUs. Each GPU now has the average gradient.
- Each GPU runs the optimizer step. Because all start identical and all see the same gradient, all end identical.
- Loop.
In PyTorch, this is DistributedDataParallel (DDP). In TensorFlow, MirroredStrategy. In JAX, pmap, or jit (which absorbed pjit) with replicated parameters and the batch sharded across devices.
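The steps listed above can be simulated without any GPU. The sketch below is not how DDP or any framework is implemented; it is a toy under stated assumptions: a two-parameter linear model with squared-error loss, four simulated replicas, and a batch that divides evenly. All function names (`grad_mse`, `all_reduce_mean`, `data_parallel_step`) are hypothetical. It runs the data-parallel loop next to an ordinary single-replica loop on the full batch so the two can be compared.

```python
def grad_mse(params, batch):
    """Mean gradient of 0.5 * (w*x + b - y)**2 over a batch of (x, y) pairs."""
    w, b = params
    gw = sum((w * x + b - y) * x for x, y in batch) / len(batch)
    gb = sum((w * x + b - y) for x, y in batch) / len(batch)
    return [gw, gb]


def all_reduce_mean(local_grads):
    """Stand-in for the collective: every replica receives the element-wise mean."""
    n = len(local_grads)
    mean = [sum(g[i] for g in local_grads) / n for i in range(len(local_grads[0]))]
    return [list(mean) for _ in range(n)]


def data_parallel_step(replicas, global_batch, lr):
    """One synchronous step: shard, local forward/backward, all-reduce, identical update."""
    n = len(replicas)
    per = len(global_batch) // n  # assumption: the global batch divides evenly
    shards = [global_batch[k * per:(k + 1) * per] for k in range(n)]
    local = [grad_mse(replicas[k], shards[k]) for k in range(n)]
    synced = all_reduce_mean(local)
    return [[p - lr * g for p, g in zip(replicas[k], synced[k])] for k in range(n)]


# Toy data from y = 2x + 1; four replicas start from identical parameters.
data = [(float(x), 2.0 * x + 1.0) for x in range(8)]
replicas = [[0.0, 0.0] for _ in range(4)]
single = [0.0, 0.0]  # reference: one replica training on the whole batch

for _ in range(200):
    replicas = data_parallel_step(replicas, data, lr=0.02)
    g = grad_mse(single, data)
    single = [p - 0.02 * gi for p, gi in zip(single, g)]

print(replicas[0], single)  # the replicas agree with each other and track the reference
```

The replicas never exchange parameters, only gradients, yet they stay in sync because they start identical and apply identical updates.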
ASCII Diagram
8-GPU Data Parallel:

           Global batch (size 256)
                      |
       split into 8 sub-batches of 32
                      |
 +-----+-----+-----+-----+-----+-----+-----+
 |     |     |     |     |     |     |     |
 v     v     v     v     v     v     v     v
GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7
Full  Full  Full  Full  Full  Full  Full  Full
Model Model Model Model Model Model Model Model
 |     |     |     |     |     |     |     |
fwd   fwd   fwd   fwd   fwd   fwd   fwd   fwd
bwd   bwd   bwd   bwd   bwd   bwd   bwd   bwd
grad  grad  grad  grad  grad  grad  grad  grad
 |_____|_____|_____|_____|_____|_____|_____|
            ALL-REDUCE GRADIENTS
 |     |     |     |     |     |     |     |
step  step  step  step  step  step  step  step
 |     |     |     |     |     |     |     |
 v     v     v     v     v     v     v     v
(all 8 models still identical, on to next step)
What’s Clever
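Data parallel is a free lunch up to a point: its only communication cost is one all-reduce of the gradients per step, and the math is identical to single-GPU SGD on the larger batch, provided the loss is averaged over examples and no layer depends on per-GPU batch statistics (BatchNorm is the usual exception). Throughput scales near-linearly with GPU count until communication or batch-size effects intervene.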
The mathematical equivalence is what makes it so popular: a researcher’s single-GPU training script can be adapted with a few lines and produce the same results, up to floating-point noise from the order in which gradients are summed.
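The equivalence can be written out in one line. The notation here is an assumption introduced for illustration: $N$ GPUs, per-GPU sub-batch size $b$, global batch size $B = Nb$, sub-batch index sets $\mathcal{B}_k$, and a loss that is a plain mean over examples. Under those assumptions, averaging the local gradients $g_k$ reproduces the full-batch gradient:

```latex
g_{\text{global}}
  = \frac{1}{B} \sum_{i=1}^{B} \nabla_\theta \, \ell(x_i; \theta)
  = \frac{1}{N} \sum_{k=1}^{N}
      \underbrace{\frac{1}{b} \sum_{i \in \mathcal{B}_k} \nabla_\theta \, \ell(x_i; \theta)}_{g_k}
  , \qquad B = N b .
```

The identity relies on all sub-batches having the same size $b$; with unequal sub-batches the simple mean of the $g_k$ would have to become a weighted mean.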
Limitations
- Per-GPU memory is unchanged. Each GPU holds the full model. Once the model + activations + optimizer state exceed per-GPU memory, adding more GPUs does not help; you need ZeRO/FSDP or model parallelism.
- Effective batch size grows with GPUs. With 64 GPUs at micro-batch 4, the effective batch is 256; at 1024 GPUs, it is 4096. Very large batches can degrade convergence: each step covers more examples, so a fixed data budget yields fewer optimizer steps. Learning-rate scaling helps, but it becomes unreliable past a batch size of roughly 10K.
- Communication is bandwidth-bound. All-reduce moves a volume of gradient data proportional to model size. For large models on slow interconnects, the all-reduce becomes the bottleneck.
- Stragglers. If one GPU is slower (thermal throttling, scheduling jitter), the all-reduce waits for it. Tail latency hurts.
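The first three limitations lend themselves to back-of-envelope arithmetic. The helpers below are a rough sketch, not a performance model, and every name and default is an assumption: 16 bytes per parameter is the common accounting for mixed-precision Adam (weights, gradients, and optimizer state, excluding activations), the learning-rate rule is the linear-scaling heuristic, and the all-reduce estimate assumes a ring algorithm limited only by a hypothetical 25 GB/s link.

```python
def effective_batch(num_gpus, micro_batch):
    """Global batch size seen by the optimizer in pure data parallel."""
    return num_gpus * micro_batch


def linear_scaled_lr(base_lr, base_batch, new_batch):
    """Linear-scaling heuristic: grow the learning rate in proportion to the batch."""
    return base_lr * new_batch / base_batch


def replica_state_gb(num_params, bytes_per_param=16):
    """Per-GPU memory for weights + gradients + optimizer state, in GB (1e9 bytes).

    bytes_per_param=16 is an assumed figure for mixed-precision Adam; activations
    are not included. Adding GPUs does not change this number in pure data parallel.
    """
    return num_params * bytes_per_param / 1e9


def ring_all_reduce_seconds(num_params, num_gpus, bytes_per_grad=2, link_gb_per_s=25.0):
    """Bandwidth-only estimate: a ring all-reduce sends about 2*(N-1)/N times
    the gradient size from each GPU. Latency and compute overlap are ignored."""
    grad_bytes = num_params * bytes_per_grad
    traffic = 2 * (num_gpus - 1) / num_gpus * grad_bytes
    return traffic / (link_gb_per_s * 1e9)


print(effective_batch(64, 4))    # 256
print(effective_batch(1024, 4))  # 4096
print(linear_scaled_lr(0.1, base_batch=256, new_batch=4096))
print(replica_state_gb(1_000_000_000))
print(ring_all_reduce_seconds(1_000_000_000, num_gpus=8))
```

Note that the per-GPU traffic term 2*(N-1)/N approaches 2 as N grows, so under this model the cost depends on model size and link speed far more than on GPU count.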
Key Sources
- zero-memory-optimizations-trillion-parameter-models: ZeRO is data parallelism with optimizer state, gradients, and parameters sharded across GPUs
- pytorch-fsdp-fully-sharded-data-parallel: FSDP is PyTorch's native implementation of ZeRO-style sharding
Related Concepts
- distributed-training: broader category
- model-parallel: the alternative axis when the model does not fit
- memory-efficiency: the constraint that pushes beyond pure DP
- optimization: large-batch training requires careful learning rate scaling
Open Questions
- Asynchronous DP. Removes the synchronization barrier, so workers update independently. Convergence is fragile because workers apply stale gradients, but there is no straggler problem. Less common since DDP became fast.
- Local SGD / federated. Workers do K local steps before averaging, reducing communication frequency. Trade-off: a larger effective batch per synchronization, and replicas that drift apart between averages, which may hurt convergence.
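The Local SGD idea can be sketched in a few lines. This is a toy under the same kind of assumptions as before: a two-parameter linear model with squared-error loss, two simulated workers with fixed data shards, and plain parameter averaging. The names `grad_mse` and `local_sgd_round` are hypothetical. With `local_steps=1` and equal shard sizes, averaging parameters is the same as averaging gradients, so the round reduces to ordinary synchronous data parallel.

```python
def grad_mse(params, batch):
    """Mean gradient of 0.5 * (w*x + b - y)**2 over a batch of (x, y) pairs."""
    w, b = params
    gw = sum((w * x + b - y) * x for x, y in batch) / len(batch)
    gb = sum((w * x + b - y) for x, y in batch) / len(batch)
    return [gw, gb]


def local_sgd_round(replicas, shards, lr, local_steps):
    """Each worker takes `local_steps` SGD steps on its own shard with no
    communication, then all workers average their parameters."""
    updated = []
    for params, shard in zip(replicas, shards):
        p = list(params)
        for _ in range(local_steps):
            g = grad_mse(p, shard)
            p = [pi - lr * gi for pi, gi in zip(p, g)]
        updated.append(p)
    n = len(updated)
    mean = [sum(p[i] for p in updated) / n for i in range(len(updated[0]))]
    return [list(mean) for _ in range(n)]  # every worker restarts from the average


data = [(float(x), 2.0 * x + 1.0) for x in range(8)]
shards = [data[:4], data[4:]]
start = [[0.0, 0.0], [0.0, 0.0]]

sync_like = local_sgd_round(start, shards, lr=0.02, local_steps=1)    # one sync per step
fewer_syncs = local_sgd_round(start, shards, lr=0.02, local_steps=4)  # one sync per 4 steps
print(sync_like[0], fewer_syncs[0])
```

Between averages each worker follows the gradient of its own shard only, which is where the drift mentioned above comes from.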