The Problem
Training neural networks in single precision (fp32) doubles the memory cost of every tensor compared to half precision (fp16 or bf16) and leaves the much faster 16-bit Tensor Core matmul paths unused. But naively training in fp16 fails: small gradients underflow (anything below fp16’s smallest subnormal, about 6e-8, flushes to zero), the loss diverges, and the model never converges. You want fp16’s speed and memory benefits without fp16’s numerical fragility.
The Key Insight
Use different precisions for different things. Forward and backward passes run in fp16 (fast Tensor Core matmuls, cheap memory). The master weights and optimizer state stay in fp32 (this preserves precision in the weights and moment estimates over many steps). A “loss scale” multiplier shifts gradients up before the backward pass and back down afterwards, dodging fp16’s underflow region.
bf16 (brain float, 16 bits with fp32-style exponent) sidesteps the underflow problem entirely at a small cost in mantissa precision. On hardware that supports it (A100, H100, TPUs), bf16 mixed precision is even simpler.
Mechanism in Plain English
fp16 mixed precision (NVIDIA AMP, original):
- Master copy of weights in fp32. This is the “ground truth.”
- Before forward: cast weights to fp16. Compute forward in fp16.
- Multiply loss by a large scale factor S (e.g., 1024 or 65536) to push gradients above fp16’s underflow region.
- Backward in fp16, producing scaled fp16 gradients.
- Convert gradients to fp32 and divide by S to undo the scaling.
- Optimizer step in fp32 using the unscaled gradients.
- Update fp32 master weights.
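The fp16 steps above can be sketched on a one-weight toy model. This is an illustration under stated assumptions, not NVIDIA AMP itself: fp16 and fp32 rounding are emulated with Python’s struct module, plain SGD stands in for the optimizer, and the names to_fp16, to_fp32, and mixed_precision_step are hypothetical.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE fp16 value; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def to_fp32(x: float) -> float:
    """Round to the nearest IEEE fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def mixed_precision_step(master_w, x, target, lr, scale):
    """One step on loss = 0.5 * (w*x - target)**2 with emulated fp16 compute."""
    w16 = to_fp16(master_w)                        # cast master weight to fp16
    x16, t16 = to_fp16(x), to_fp16(target)
    err16 = to_fp16(to_fp16(w16 * x16) - t16)      # forward pass: d(loss)/d(pred)
    # Multiplying the loss by `scale` multiplies every gradient by `scale`.
    upstream16 = to_fp16(scale * err16)
    scaled_grad16 = to_fp16(upstream16 * x16)      # backward pass in fp16
    grad32 = to_fp32(scaled_grad16) / scale        # upcast, then unscale
    return to_fp32(master_w - lr * grad32)         # SGD update on the fp32 master weight

# Hypothetical values: the true gradient is about -1e-8, too small for fp16.
no_scaling = mixed_precision_step(0.0, 1e-4, 1e-4, lr=0.1, scale=1.0)
with_scaling = mixed_precision_step(0.0, 1e-4, 1e-4, lr=0.1, scale=65536.0)
print(no_scaling, with_scaling)
```

With scale=1.0 the gradient flushes to zero and the weight does not move; with scale=65536.0 the update of roughly 1e-9 reaches the master weight. A real implementation would also check the gradients for inf or NaN before stepping.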
bf16 mixed precision (modern):
- Master weights in fp32.
- Forward and backward in bf16 (no scaling needed; bf16 has fp32-equivalent dynamic range).
- Optimizer step in fp32 using bf16-cast-to-fp32 gradients.
bf16 sacrifices some mantissa precision (7 mantissa bits vs fp16’s 10) but its exponent range matches fp32, so underflow is virtually eliminated. This made bf16 the modern default starting around 2020.
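The range-versus-precision trade-off can be checked numerically. The sketch below emulates both formats in pure Python; to_fp16 and to_bf16 are hypothetical helper names, and the bf16 helper assumes round-to-nearest-even on the fp32 bit pattern and is only meant for finite inputs.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE fp16 value; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def to_bf16(x: float) -> float:
    """Round to bf16 by keeping the top 16 bits of the fp32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000                    # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for value in (1e-8, 1.001, 1e5):
    print(value, to_fp16(value), to_bf16(value))
```

In this emulation 1e-8 becomes zero in fp16 but survives in bf16, 1.001 keeps its fractional part in fp16 but rounds to 1.0 in bf16, and 1e5 overflows fp16 while bf16 holds it to within about 1%.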
ASCII Diagram
fp16 mixed precision:
fp32 master weights
        |
        |  cast to fp16 (every step)
        v
fp16 weights ---> [Forward pass in fp16] ---> fp16 activations
                                                      |
                                                      v
fp16 gradients <--- [Backward pass in fp16] <--- fp16 loss * S
        |                             (scale up to avoid underflow)
        |  upcast to fp32, divide by S
        v
fp32 gradients ---> [Adam optimizer in fp32] ---> updated fp32 master weights
What’s Clever
The asymmetry: forward and backward have many parallel fp16-friendly matmuls (Tensor Core acceleration). The optimizer step is a small piece of work but it accumulates over many iterations; precision matters there. By doing the heavy compute in fp16 and the precision-sensitive accumulation in fp32, you get speed where it counts and accuracy where it counts.
Loss scaling is a clever workaround for fp16’s narrow dynamic range. fp16’s normal range runs from about 6e-5 to 65504 (subnormals extend down to about 6e-8 with reduced precision). Gradient values are typically in [1e-8, 1e-1]; the small ones underflow. Multiply by 65536 and they fit. Upcast to fp32, divide by 65536, and they’re back to their true values. Because S is a power of two, scaling only shifts the exponent, so no information is lost unless a scaled value overflows 65504. This is the main reason original AMP works.
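A short numerical check of that range argument, assuming the gradient magnitudes quoted above and an emulated fp16. The helpers to_fp16 and survives, and the 0.1% tolerance, are illustrative choices rather than part of any AMP implementation.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE fp16 value; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def survives(grad: float, scale: float, tol: float = 1e-3) -> bool:
    """True if grad * scale, stored in fp16, recovers grad within tol after unscaling."""
    stored = to_fp16(grad * scale)
    return abs(stored / scale - grad) <= tol * abs(grad)

grads = [1e-8, 1e-5, 1e-1]   # magnitudes spanning the range quoted in the text
for scale in (1.0, 65536.0, 2.0 ** 20):
    print(scale, [survives(g, scale) for g in grads])
```

Without scaling the smallest gradient is lost, at 65536 all three survive, and at 2^20 the largest one overflows to inf. The usable scale is therefore a window, bounded below by underflow and above by overflow.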
Concrete Walkthrough
GPT-3 175B training memory accounting:
Pure fp32 (no mixed precision):
Parameters: 175B * 4 bytes = 700 GB
Gradients: 700 GB
Adam m: 700 GB
Adam v: 700 GB
TOTAL state: 2800 GB on every replica.
fp16 mixed precision:
fp16 params: 175B * 2 = 350 GB
fp32 master params: 700 GB
fp16 gradients: 350 GB
fp32 Adam m: 700 GB
fp32 Adam v: 700 GB
TOTAL state: 2800 GB. Wait, the same?
Yes: both layouts cost 16 bytes per parameter, so model state does not shrink.
The savings come from activations (fp16 instead of fp32), which are often
the dominant cost during training. fp16 activations take half the memory.
Activation memory savings: ~50%.
bf16 mixed precision:
Same accounting; bf16 activations are also half the size of fp32 activations.
Plus: no need for loss scaling (and the buggy edge cases that come with it).
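The accounting above can be reproduced with a few lines of arithmetic. This is a sketch: the layout dictionaries restate the byte counts from the walkthrough, while model_state_gb and the activation element count are hypothetical and chosen only for illustration.

```python
GB = 10**9

def model_state_gb(n_params: int, layout: dict) -> int:
    """Total model-state memory in GB for a {tensor_name: bytes_per_param} layout."""
    return n_params * sum(layout.values()) // GB

FP32_LAYOUT = {"params": 4, "grads": 4, "adam_m": 4, "adam_v": 4}
MIXED_LAYOUT = {"half_params": 2, "master_params": 4, "half_grads": 2,
                "adam_m": 4, "adam_v": 4}

n_params = 175 * 10**9
fp32_total = model_state_gb(n_params, FP32_LAYOUT)
mixed_total = model_state_gb(n_params, MIXED_LAYOUT)
print(fp32_total, mixed_total)  # 2800 2800

# Activations scale with bytes per element; this element count is a made-up placeholder.
activation_elements = 10**12
fp32_act_gb = activation_elements * 4 // GB
half_act_gb = activation_elements * 2 // GB
print(fp32_act_gb, half_act_gb)  # 4000 2000
```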
Throughput on A100 Tensor Core:
fp32: 19.5 TFLOP/s
fp16/bf16: 312 TFLOP/s (16x higher peak on Tensor Cores, roughly 2x faster end-to-end)
The roughly 2x end-to-end speedup is the headline (the exact figure depends on the workload);
the memory savings on activations let you fit larger batches, which further improves throughput.
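One way to reconcile a 16x peak ratio with a roughly 2x end-to-end gain is Amdahl’s law. The sketch below assumes the full 16x is realized on the matmul portion of a step and that nothing else speeds up; matmul_fraction is a hypothetical workload parameter, not a measurement.

```python
def end_to_end_speedup(matmul_fraction: float, matmul_speedup: float = 16.0) -> float:
    """Amdahl's law: only the matmul share of the step time gets faster."""
    remaining = (1.0 - matmul_fraction) + matmul_fraction / matmul_speedup
    return 1.0 / remaining

for fraction in (0.25, 8 / 15, 0.9):
    print(f"{fraction:.2f} -> {end_to_end_speedup(fraction):.2f}x")
```

Under these assumptions, a 2x overall gain corresponds to a little over half of the fp32 step time (8/15) being spent in matmuls; the rest is memory-bound or non-matmul work that does not benefit from Tensor Cores.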
What’s Clever (continued)
The hidden subtlety: fp16’s effective precision near 1.0 is about 0.001 (relative). Adam’s moment estimates accumulate over many steps with fractional updates; they need more precision than fp16 offers. Keeping m and v in fp32 is non-negotiable. This is a second reason pure fp16 training fails, on top of gradient underflow: small updates are rounded away and the optimizer state degrades as training proceeds.
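The effect of limited precision on slow accumulation can be seen with a plain running sum, used here as a stand-in for any state that changes by small increments (it is not Adam itself). The helpers and the values 1.0 and 1e-4 are hypothetical; rounding is emulated with Python’s struct module.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE fp16 value (inputs here stay well within range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round to the nearest IEEE fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def accumulate(start: float, increment: float, steps: int, cast) -> float:
    """Add `increment` to `start` `steps` times, rounding with `cast` after every add."""
    total = cast(start)
    inc = cast(increment)
    for _ in range(steps):
        total = cast(total + inc)
    return total

fp16_total = accumulate(1.0, 1e-4, 1000, to_fp16)
fp32_total = accumulate(1.0, 1e-4, 1000, to_fp32)
print(fp16_total, fp32_total)
```

In fp16 the spacing between representable values near 1.0 is about 0.001, so each 1e-4 increment rounds away and the total never leaves 1.0. In fp32 the same loop ends close to 1.1.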
bf16 changes the trade-off: its exponent range is wide enough that gradients don’t underflow, but its mantissa precision is lower than fp16’s (7 bits vs 10). Why does bf16 work where fp16 doesn’t? Because bf16’s wider dynamic range means you don’t need loss scaling at all, and the actual magnitudes the optimizer sees are correct (just rounded). With the master weights and Adam state still in fp32, the optimizer tolerates that rounding well in practice.
Key Sources
- zero-memory-optimizations-trillion-parameter-models — ZeRO assumes mixed precision; the optimizer-state savings are largest in fp32 components
- pytorch-fsdp-fully-sharded-data-parallel — FSDP exposes detailed mixed-precision policies
Related Concepts
- distributed-training — mixed precision is universal in distributed training today
- memory-efficiency — primary motivation
- quantization — related but distinct (quantization is for inference, mixed precision is for training)
Open Questions
- fp8 training. H100 supports fp8 Tensor Cores. Training in fp8 (e.g., NVIDIA Transformer Engine) is emerging; the precision tradeoff is more delicate than fp16/bf16.
- Per-tensor mixed precision. Different tensors may benefit from different precisions. Current frameworks expose policies; future work may auto-tune per-tensor.
- Loss scaling automation. Dynamic loss scaling adapts S during training. Robust in practice; theory is heuristic.
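Dynamic loss scaling, as mentioned in the last item, usually follows a simple back-off and grow rule. The class below is a minimal sketch of that heuristic, not any framework’s implementation: DynamicLossScaler and its method are hypothetical names, and the default values mirror commonly used settings (for example the defaults of PyTorch’s GradScaler).

```python
import math

class DynamicLossScaler:
    """Halve the scale on overflow; double it after a run of clean steps."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0

    def update(self, scaled_grads) -> bool:
        """Inspect this step's scaled gradients; return True if the step may be applied."""
        if any(not math.isfinite(g) for g in scaled_grads):
            self.scale *= self.backoff_factor   # overflow: shrink S and skip the step
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale *= self.growth_factor    # long clean run: try a larger S
            self.good_steps = 0
        return True

scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=2)
history = []
for grads in ([0.5, float("inf")], [0.1, 0.2], [0.3, 0.4]):
    applied = scaler.update(grads)
    history.append((applied, scaler.scale))
print(history)
```

The short growth_interval in the usage example is only there to make the growth step visible; the skipped step on overflow is what keeps a too-large S from corrupting the weights.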