Concepts: distributed-training | data-parallel | memory-efficiency
Builds on: zero-memory-optimizations-trillion-parameter-models

ZeRO-3 was the algorithm; DeepSpeed was the first implementation; FSDP is what happens when the same idea gets redesigned natively by the PyTorch core team, co-designed with PyTorch’s Tensor implementation, the dispatcher, and the CUDA caching allocator. The contribution is less a new algorithm than a careful set of engineering decisions that turn ZeRO-3 from “an extra library you bolt on” into “a native PyTorch wrapper that just works.”

The core idea

The analogy: ZeRO-3 is a brilliant building plan that solves a real problem (memory at scale). DeepSpeed is the first contractor who built it — but the plumbing crosses the electrical, the HVAC ducts go through load-bearing walls, and you have to coordinate four separate subcontractors who do not talk to each other. FSDP is the second team, with the architect’s revised drawings, where every system is integrated up front.

The algorithm is the same: every parameter has one canonical owner rank; ranks fetch (all-gather) parameters just before they need them and discard them afterward; gradients are reduce-scattered to their owners; the optimizer step runs locally on each rank’s owned slice.
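
These four steps can be simulated on one machine, with plain Python lists standing in for tensors and ranks. This is a toy sketch, not FSDP’s code: the helper names (shard, all_gather, reduce_scatter, zero3_step) are hypothetical, the collectives are modeled as list operations, and each rank’s loss is assumed to be a simple squared error against its own target, so its gradient is w - target_r. The property worth checking is that the sharded update equals an ordinary data-parallel SGD step.

```python
from typing import Callable, List

Vec = List[float]

def shard(flat: Vec, world_size: int) -> List[Vec]:
    """Split a flat parameter vector into equal contiguous shards (assumes divisibility)."""
    n = len(flat) // world_size
    return [flat[r * n:(r + 1) * n] for r in range(world_size)]

def all_gather(shards: List[Vec]) -> Vec:
    """Every rank ends up with the concatenation of all shards."""
    return [x for s in shards for x in s]

def reduce_scatter(full_grads: List[Vec], world_size: int) -> List[Vec]:
    """Average the ranks' full gradients elementwise, then give rank r only its own slice."""
    averaged = [sum(vals) / world_size for vals in zip(*full_grads)]
    return shard(averaged, world_size)

def zero3_step(param_shards: List[Vec], grad_fns: List[Callable[[Vec], Vec]], lr: float) -> List[Vec]:
    world_size = len(param_shards)
    full_params = all_gather(param_shards)                 # fetch just before use
    full_grads = [grad(full_params) for grad in grad_fns]  # each rank: forward/backward on its own data
    del full_params                                        # discard the gathered copy after use
    grad_shards = reduce_scatter(full_grads, world_size)   # each rank keeps only the grads it owns
    # Optimizer step (plain SGD here) runs locally on each rank's owned slice.
    return [[p - lr * g for p, g in zip(ps, gs)] for ps, gs in zip(param_shards, grad_shards)]

# Usage: 8 parameters sharded over 4 ranks; rank r pulls every weight toward r.
w0 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
targets = [[float(r)] * len(w0) for r in range(4)]
grad_fns = [lambda w, t=t: [wi - ti for wi, ti in zip(w, t)] for t in targets]
new_shards = zero3_step(shard(w0, 4), grad_fns, lr=0.1)

# Reference: unsharded data-parallel SGD with the averaged gradient.
reference = [wi - 0.1 * (sum(wi - t[i] for t in targets) / 4) for i, wi in enumerate(w0)]
print(all(abs(a - b) < 1e-12 for a, b in zip(all_gather(new_shards), reference)))  # True
```

Real FSDP sums gradients with a collective and applies a divide factor; averaging inside reduce_scatter here is just the simplest way to match standard data-parallel semantics.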

“We introduce PyTorch Fully Sharded Data Parallel (FSDP) as an industry-grade solution for large model training. FSDP has been closely co-designed with several key PyTorch core components including Tensor implementation, dispatcher system, and CUDA memory caching allocator.”

The engineering wins:

  1. Native autograd integration. No custom backward passes are needed. FSDP wraps modules and uses PyTorch hooks to trigger all-gathers (before forward, and again before backward) and reduce-scatters (after backward).
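  2. Coordinated memory allocation. The CUDA caching allocator does not handle FSDP’s “burst” allocation pattern (the all-gathered parameter buffer comes and goes per layer) well. The buffer is produced on a communication stream and consumed on the compute stream, so its memory cannot safely be reused until the GPU is done with it; if the CPU thread runs ahead, the allocator keeps handing out fresh blocks, and memory over-allocates and fragments. Rather than rely on deferred cross-stream frees (Tensor.record_stream), FSDP throttles the CPU thread so that freed blocks are actually reused.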
  3. Flexible sharding strategies. Beyond ZeRO-3-style full sharding (FULL_SHARD), the ShardingStrategy enum supports SHARD_GRAD_OP (gradient + optimizer sharding only, like ZeRO-2), HYBRID_SHARD (full sharding within a node, replication across nodes; useful when cross-node bandwidth is much lower than intra-node), and NO_SHARD (plain DDP-style replication).
  4. Mixed precision policies. A MixedPrecision config sets the dtype used for all-gathered parameters during compute (param_dtype), for gradient reduction (reduce_dtype), and for buffers (buffer_dtype). The sharded master parameters keep their original dtype, typically fp32.
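
Why keep the master parameters in fp32? bfloat16 keeps only 8 bits of significand precision, so near 1.0 the gap between representable values is 2^-7 above and 2^-8 below, and any update much smaller than that rounds away. The sketch below emulates both formats with Python’s struct module. The rounding helper is hand-written (round-to-nearest-even on the upper 16 bits of an fp32 pattern, finite inputs only), not a PyTorch API, and the 1e-4 update size is an arbitrary illustration.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 (ties to even) by keeping the top 16 bits of the fp32 pattern.
    Finite inputs only; NaN handling is omitted for brevity."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias: ties go to the even bf16 value
    bits &= 0xFFFF0000                    # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def apply_updates(cast, steps: int = 1000, update: float = 1e-4) -> float:
    """Repeatedly subtract a small update, storing the weight in the format given by `cast`."""
    w = cast(1.0)
    for _ in range(steps):
        w = cast(w - update)
    return w

print(apply_updates(to_bf16))  # 1.0: every update rounds away when the weight lives in bf16
print(apply_updates(to_fp32))  # about 0.9: fp32 master weights accumulate the updates
```

This is the reason the usual recipe computes in bf16 but applies optimizer updates to fp32 shards.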

What’s clever — find the instinct

The non-obvious move is the HYBRID_SHARD mode. In a multi-node training setup, cross-node bandwidth (InfiniBand, ~25 GB/s per GPU on a typical HDR setup) is much lower than intra-node NVLink (~300 GB/s per direction on A100). FULL_SHARD (full ZeRO-3) all-gathers across all GPUs, including across nodes, for every layer in both forward and backward, so it bottlenecks on cross-node bandwidth.
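
A bandwidth-only cost model for a ring all-gather makes the gap concrete: each rank sends and receives (N-1)/N of the gathered tensor, so time is roughly (N-1)/N × bytes / (bandwidth of the slowest link in the ring). This ignores latency, NCCL’s tree and multi-ring algorithms, and overlap with compute, so treat it as an order-of-magnitude sketch. The function name is hypothetical, the 1 GB layer size reuses the walkthrough’s rough estimate, and the bandwidths are the figures quoted above.

```python
def ring_all_gather_seconds(gathered_bytes: float, num_ranks: int, link_gb_per_s: float) -> float:
    """Bandwidth-only ring model: each rank moves (N-1)/N of the gathered tensor over its slowest link."""
    return (num_ranks - 1) / num_ranks * gathered_bytes / (link_gb_per_s * 1e9)

layer_bytes = 1e9  # ~1 GB gathered layer, the walkthrough's rough estimate

# FULL_SHARD over 8 nodes x 8 GPUs: the ring crosses nodes, so InfiniBand (~25 GB/s) is the bottleneck.
cross_node = ring_all_gather_seconds(layer_bytes, num_ranks=64, link_gb_per_s=25)
# HYBRID_SHARD: the ring stays inside one node on NVLink (~300 GB/s).
intra_node = ring_all_gather_seconds(layer_bytes, num_ranks=8, link_gb_per_s=300)

print(f"cross-node: {cross_node * 1e3:.1f} ms, intra-node: {intra_node * 1e3:.1f} ms")
```

Under these assumptions the cross-node gather is more than 10x slower per layer, and FULL_SHARD pays it twice per layer (forward and backward).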

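HYBRID_SHARD: shard within each node (the 8 GPUs all-gather over fast NVLink) and replicate across nodes (every node holds the same set of shards). Parameter all-gathers never leave the node. The only cross-node traffic is a gradient all-reduce among the ranks that hold the same shard; it is issued during backward and overlapped with compute, much like DDP’s bucketed all-reduce, and in total it moves about one gradient’s worth of data per step, the same as standard data parallel. What disappears is the cross-node parameter all-gather for every layer.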
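
The two process groups HYBRID_SHARD relies on can be written down directly. This sketch assumes ranks are numbered node by node (rank = node × gpus_per_node + local_rank, the usual torchrun layout), and the function name is hypothetical. Parameters are all-gathered and gradients reduce-scattered within shard_group; gradient shards are then all-reduced within replicate_group, which holds one GPU per node.

```python
from typing import List, Tuple

def hybrid_shard_groups(rank: int, world_size: int, gpus_per_node: int) -> Tuple[List[int], List[int]]:
    """Return (shard_group, replicate_group) for `rank` under node-major rank numbering."""
    node, local_rank = divmod(rank, gpus_per_node)
    # Shard group: all GPUs on this node; all-gather / reduce-scatter stay on NVLink.
    shard_group = list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
    # Replicate group: the GPU holding the same shard on every node; only gradient shards cross nodes.
    replicate_group = list(range(local_rank, world_size, gpus_per_node))
    return shard_group, replicate_group

shard_group, replicate_group = hybrid_shard_groups(rank=13, world_size=64, gpus_per_node=8)
print(shard_group)      # [8, 9, 10, 11, 12, 13, 14, 15]
print(replicate_group)  # [5, 13, 21, 29, 37, 45, 53, 61]
```

PyTorch can build an equivalent pair of groups itself or accept them explicitly; the exact arguments depend on the version, so check the FullyShardedDataParallel documentation before relying on either path.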

Memory cost: each node holds a full copy of the parameters (vs ZeRO-3’s full partitioning), so memory is cut by intra_node_size (typically 8) instead of world_size (typically 64+). For models that fit with this looser cut, HYBRID_SHARD is dramatically faster than FULL_SHARD on cross-node clusters.

“FSDP natively incorporates a range of techniques and settings to optimize resource utilization across a variety of hardware configurations.”

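The second clever move is the rate limiter. ZeRO-3’s pattern is to all-gather the parameters for layer N+1 while layer N is computing; the prefetch keeps the GPU busy. But the CPU thread can run far ahead of the GPU, issuing all-gathers whose buffers cannot be freed until their layers have run, so the buffers pile up and exhaust memory. FSDP’s rate limiter blocks the CPU once a small number of all-gathers are in flight (at most two in the paper’s design; exposed in PyTorch as the limit_all_gathers flag). How far ahead to prefetch is a separate knob (backward_prefetch and forward_prefetch), so you can prefetch more aggressively when you have memory headroom.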
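
A toy model shows why bounding in-flight all-gathers bounds memory. It assumes the worst case the rate limiter guards against: the CPU issues all-gathers arbitrarily far ahead of the GPU, and a gathered buffer is freed only after its layer has computed. With a limit, issuing a new all-gather first waits for the oldest in-flight layer to finish and free its buffer. The function is hypothetical and counts only gathered-parameter buffers.

```python
from collections import deque
from typing import Deque, Optional, Sequence

def peak_gathered_bytes(layer_bytes: Sequence[int], limit: Optional[int] = None) -> int:
    """Peak bytes held by gathered-parameter buffers when the CPU runs ahead of the GPU.

    limit=None models no rate limiter (the CPU never waits); limit=k allows at most
    k all-gathers in flight before the CPU blocks."""
    in_flight: Deque[int] = deque()
    live = peak = 0
    for size in layer_bytes:
        if limit is not None and len(in_flight) >= limit:
            # Rate limiter: block until the oldest gathered layer has run and its buffer is freed.
            live -= in_flight.popleft()
        in_flight.append(size)   # issue the next all-gather
        live += size
        peak = max(peak, live)
    return peak

layers = [1_000_000_000] * 40   # 40 blocks, ~1 GB each once gathered
print(peak_gathered_bytes(layers))           # 40000000000: every buffer alive at once
print(peak_gathered_bytes(layers, limit=2))  # 2000000000: at most two buffers
```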

The third clever move is flat parameter representation. Each FSDP module flattens its parameters into one contiguous tensor and shards that tensor. This makes the all-gather one large communication operation per module rather than many small ones (which would each have launch overhead). Modern collectives (NCCL) are vastly more efficient on one large all-gather than many small ones.
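
The flatten, pad, shard, and unflatten bookkeeping fits in a few lines. This is a sketch with hypothetical helper names and 1-D Python lists as tensors. Zero-padding the flat buffer so it divides evenly by world_size mirrors what FSDP does with its flat parameter; real FSDP additionally exposes the original parameters as views into the gathered buffer rather than as copies.

```python
import math
from typing import Dict, List, Tuple

Layout = List[Tuple[str, int]]  # (parameter name, numel) in flattening order

def flatten_and_shard(params: Dict[str, List[float]], world_size: int) -> Tuple[List[List[float]], Layout]:
    """Concatenate a unit's parameters into one flat buffer, pad it, and split it evenly."""
    layout = [(name, len(values)) for name, values in params.items()]
    flat = [x for values in params.values() for x in values]
    padded_len = math.ceil(len(flat) / world_size) * world_size
    flat += [0.0] * (padded_len - len(flat))   # zero padding so every rank gets an equal shard
    shard_len = padded_len // world_size
    shards = [flat[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]
    return shards, layout

def gather_and_unflatten(shards: List[List[float]], layout: Layout) -> Dict[str, List[float]]:
    """One all-gather (concatenation) of the whole unit, then slice each parameter back out."""
    flat = [x for shard in shards for x in shard]
    out: Dict[str, List[float]] = {}
    offset = 0
    for name, numel in layout:
        out[name] = flat[offset:offset + numel]   # padding past the last parameter is ignored
        offset += numel
    return out

params = {"weight": [1.0, 2.0, 3.0, 4.0, 5.0], "bias": [6.0, 7.0]}
shards, layout = flatten_and_shard(params, world_size=4)
print(shards)                                          # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 0.0]]
print(gather_and_unflatten(shards, layout) == params)  # True
```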

Walkthrough: 13B LLaMA on 64 A100s

Setup: 13B parameter LLaMA-2 architecture.
       64 A100 80GB GPUs, 8 nodes of 8 GPUs each.
       Mixed precision: bf16 forward/backward, fp32 master + Adam states.

Memory accounting (per GPU):

FULL_SHARD (ZeRO-3 equivalent):
  Sharded fp32 master:    13e9 * 4 / 64 = 0.81 GB
  Sharded Adam m:         0.81 GB
  Sharded Adam v:         0.81 GB
  Sharded gradients (bf16): 13e9 * 2 / 64 = 0.41 GB
  Sharded params (bf16):  0.41 GB
  All-gather temp buffer: ~size of largest layer ≈ 1 GB
  Activations (recompute on): ~3 GB
  TOTAL: ~7.3 GB per GPU.

  Communication per layer (fwd+bwd):
    2 all-gathers of params (forward, then again in backward; across all 64 GPUs, cross-node bandwidth)
    1 reduce-scatter of grads (across all 64 GPUs)
  Cross-node bandwidth limited: ~70% MFU.

HYBRID_SHARD:
  Sharded across 8 GPUs/node, replicated across nodes:
  Sharded fp32 master:    13e9 * 4 / 8 = 6.5 GB
  Sharded Adam m, v:      6.5 + 6.5 GB
  Sharded gradients:      13e9 * 2 / 8 = 3.25 GB
  Sharded params (bf16):  3.25 GB
  All-gather temp buffer: ~size of largest layer ≈ 1 GB
  Activations: ~3 GB
  TOTAL: ~30 GB per GPU. Fits in 80 GB.

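  Communication per layer:
    2 all-gathers (forward and backward, within node, NVLink fast)
    1 reduce-scatter (within node, NVLink fast)
  Cross-node:
    gradient all-reduce only, on each GPU's 1/8 shard (issued during backward, like DDP's buckets);
    about one gradient's worth of traffic per step, and parameters never cross nodes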
  Result: ~85% MFU on the same cluster.

The HYBRID_SHARD trade: per-GPU memory rises from ~7.3 GB to ~30 GB because the sharded state is split 8 ways instead of 64 (still well within 80 GB), but throughput is much higher because per-layer communication stays on NVLink. This is the kind of optimization that ZeRO-3 in DeepSpeed historically required custom configuration for; FSDP exposes it as a single enum value.
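
The walkthrough’s accounting reduces to one formula: 16 bytes of sharded state per parameter (fp32 master, Adam m and v, bf16 gradients, bf16 parameter shard) divided by the sharding degree, plus the fixed gather-buffer and activation estimates. The sketch below reproduces both totals. The function is hypothetical, it uses decimal GB as the walkthrough does, and the 1 GB buffer and 3 GB activation figures are the walkthrough’s rough estimates, not measurements.

```python
def per_gpu_memory_gb(n_params: float, shard_degree: int,
                      gather_buffer_gb: float = 1.0, activations_gb: float = 3.0) -> float:
    """Per-GPU memory (decimal GB) under the walkthrough's assumptions."""
    bytes_per_param = (
        4    # fp32 master parameters
        + 4  # Adam m (fp32)
        + 4  # Adam v (fp32)
        + 2  # bf16 gradients
        + 2  # bf16 parameter shard
    )
    sharded_state_gb = n_params * bytes_per_param / shard_degree / 1e9
    return sharded_state_gb + gather_buffer_gb + activations_gb

print(per_gpu_memory_gb(13e9, shard_degree=64))  # 7.25: FULL_SHARD over all 64 GPUs
print(per_gpu_memory_gb(13e9, shard_degree=8))   # 30.0: HYBRID_SHARD within an 8-GPU node
print(per_gpu_memory_gb(13e9, shard_degree=1))   # 212.0: unsharded state alone is 208 GB, far beyond 80 GB
```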

Does it work? What breaks?

Model        Hardware        Strategy                       Memory/GPU   Throughput (TFLOPS per GPU)
175B model   128 A100 80GB   FULL_SHARD                     ~30 GB       173
1T model     512 A100 80GB   FULL_SHARD + activation ckpt   ~75 GB       84

Headline result: near-linear TFLOPS scalability, up to the 1T-parameter model in the table. FSDP is close to a drop-in replacement for DistributedDataParallel; changing the sharding_strategy argument moves you from DDP-like (NO_SHARD) to ZeRO-2-like (SHARD_GRAD_OP) to ZeRO-3-like (FULL_SHARD) to HYBRID_SHARD.

“FSDP is capable of achieving comparable performance to Distributed Data Parallel while providing support for significantly larger models with near-linear scalability in terms of TFLOPS.”

What breaks:

  • Sub-module wrapping policy is critical. If you wrap too coarsely (one big FSDP unit), the all-gather buffer holds most of the model and there is little communication left to overlap with compute. If you wrap too finely (every linear layer), per-collective launch overhead dominates. Rule of thumb: wrap at the transformer-block level. Bad wrapping is one of the most common causes of FSDP underperformance.
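  • Mixed precision interactions. FSDP’s MixedPrecision setting changes only the compute and communication dtypes; the sharded parameters, and therefore the optimizer states, stay in fp32 even when forward/backward runs in bf16. If you push everything to bf16 to save memory, you lose numerical stability. The paper documents these pitfalls.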
  • Activation checkpointing required for large models. FSDP shards parameters but not activations. For 70B+ models, you also need torch.utils.checkpoint or selective activation checkpointing.
  • Saving / loading checkpoints is non-trivial. Each rank holds different shards; combining them into a unified checkpoint requires coordination. FSDP exposes this through its state_dict_type setting (StateDictType.FULL_STATE_DICT, SHARDED_STATE_DICT, LOCAL_STATE_DICT), and the user has to choose correctly.
  • Custom autograd functions. If you define your own autograd.Function, FSDP’s hooks may not interact with it correctly; test it explicitly.
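
The sharded-versus-full trade-off in checkpointing can be sketched without any PyTorch API: each rank writes only its own shard (in parallel, with no gather), and an offline step concatenates the shards in rank order and strips the padding. The file naming, JSON format, and function names here are hypothetical stand-ins. In PyTorch the real machinery is SHARDED_STATE_DICT together with torch.distributed.checkpoint, which also stores layout metadata.

```python
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import List

def shard_path(directory: str, rank: int) -> str:
    return os.path.join(directory, f"shard_{rank:05d}.json")

def save_shard(directory: str, rank: int, shard: List[float]) -> None:
    """Each rank writes only its own slice, so no rank ever materializes the full model."""
    with open(shard_path(directory, rank), "w") as f:
        json.dump(shard, f)

def merge_offline(directory: str, world_size: int, numel: int) -> List[float]:
    """Offline: concatenate shards in rank order and drop the padding past `numel`."""
    flat: List[float] = []
    for rank in range(world_size):
        with open(shard_path(directory, rank)) as f:
            flat.extend(json.load(f))
    return flat[:numel]

# Usage: 7 real values padded to 8 and sharded over 4 "ranks" (threads stand in for processes).
shards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 0.0]]
with tempfile.TemporaryDirectory() as ckpt_dir:
    with ThreadPoolExecutor(max_workers=len(shards)) as pool:
        list(pool.map(lambda r: save_shard(ckpt_dir, r, shards[r]), range(len(shards))))
    merged = merge_offline(ckpt_dir, world_size=len(shards), numel=7)

print(merged)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```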

So what?

For a practitioner training large models in PyTorch:

  1. Use FSDP, not DDP, the moment your model crosses ~1B parameters. Even with the simpler ZeRO-2 equivalent (SHARD_GRAD_OP), the memory savings let you fit larger batches and avoid pipeline parallelism.
  2. HYBRID_SHARD is the pragmatic default for multi-node training. Unless your model is so large it does not fit per-node even after sharding, you want the cross-node bandwidth savings.
  3. Wrap at the transformer-block level. A common idiom: FSDP(model, auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})), where Block is your transformer-block class.
  4. Pair with activation checkpointing for very large models. FSDP’s parameter savings are necessary but not sufficient; activations dominate memory at long sequence lengths.
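  5. Mixed precision: bf16 for compute, fp32 for optimizer. This is the standard FSDP MixedPrecision setup (bf16 param_dtype, with the sharded parameters and optimizer states left in fp32) and the right default. Going further (fp16 optimizer states) is fragile.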
  6. Save sharded checkpoints; merge offline if needed. Gathering a full state dict onto rank 0 and writing it from a single process is slow at scale; sharded saves write from every rank concurrently.
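
What a transformer-block wrap policy does to the parameter layout can be sketched on a toy module tree: every subtree whose class is a block becomes its own FSDP unit, and whatever is left over (embeddings, final head) belongs to the root unit. The Node class, class names, and parameter counts are made up for illustration. This mirrors the idea behind transformer_auto_wrap_policy but is not its implementation, and it does not handle blocks nested inside blocks.

```python
from dataclasses import dataclass, field
from typing import List, Set

@dataclass
class Node:
    kind: str                                  # module class name, e.g. "TransformerBlock"
    numel: int = 0                             # parameters owned directly by this module
    children: List["Node"] = field(default_factory=list)

def total_numel(node: Node) -> int:
    return node.numel + sum(total_numel(c) for c in node.children)

def wrap_units(root: Node, wrap_kinds: Set[str]) -> List[int]:
    """Parameter count of each FSDP unit: one per wrapped subtree, plus the root's leftovers."""
    units: List[int] = []

    def unclaimed(node: Node) -> int:
        if node is not root and node.kind in wrap_kinds:
            units.append(total_numel(node))    # this subtree becomes its own unit
            return 0
        return node.numel + sum(unclaimed(c) for c in node.children)

    units.append(unclaimed(root))              # leftovers are managed by the root unit
    return units

def block() -> Node:
    return Node("TransformerBlock", children=[Node("Attention", 40), Node("MLP", 80)])

model = Node("LM", children=[Node("Embedding", 100), *[block() for _ in range(4)], Node("Linear", 100)])
print(wrap_units(model, {"TransformerBlock"}))  # [120, 120, 120, 120, 200]: one unit per block + root
print(wrap_units(model, set()))                 # [680]: one giant unit, the "too coarse" case
```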

For an L5 interview question like “what is the difference between DeepSpeed and FSDP?”: both implement ZeRO-3. DeepSpeed has a more mature feature set (offload, MoE, ZeRO-Infinity); FSDP has tighter PyTorch integration and is easier to combine with other PyTorch features. Production teams pick based on their stack: HuggingFace tends toward FSDP, Microsoft toward DeepSpeed.

Connections

Citation

arXiv:2304.11277

Zhao, Y., Gu, A., Varma, R., et al. (2023). PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. VLDB 2023. https://arxiv.org/abs/2304.11277