Lab 08: ZeRO / FSDP Memory Sharding
Annotated code reading lab. Running code is optional.
Distributed Training
Distributed training scales beyond one device by partitioning data, model state, or computation across ranks. The key questions are what is replicated, what is sharded, which collective runs on the critical path, and how optimizer semantics stay consistent.
Read code to understand the concept
What training focuses on
Core mechanism
- DDP replicates parameters, gradients and optimizer states on every rank.
- ZeRO-1 shards optimizer states; ZeRO-2 also shards gradients; ZeRO-3 also shards parameters.
- FSDP-style execution commonly gathers parameter shards when a wrapped unit needs them, but exact timing, prefetch and overlap depend on PyTorch configuration and version.
- ReduceScatter lets ranks aggregate gradients while keeping only their shard.
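To make the last bullet concrete, here is a minimal single-process sketch of the collectives' semantics. It uses plain Python lists in place of tensors and does not call `torch.distributed`. The helper names (`reduce_scatter`, `all_gather`, `all_reduce`) are illustrative stand-ins written for this note, not a library API. It shows that reduce-scatter followed by all-gather yields the same result as all-reduce, while the intermediate state holds only one shard per rank.

```python
# Single-process simulation of collective semantics (illustrative only).
# per_rank_* arguments are lists indexed by rank; no real communication happens.

def reduce_scatter(per_rank_grads):
    """Each rank ends up with the element-wise sum of only its own shard."""
    world = len(per_rank_grads)
    n = len(per_rank_grads[0])
    assert n % world == 0, "sketch assumes the tensor splits evenly across ranks"
    shard = n // world
    out = []
    for rank in range(world):
        lo, hi = rank * shard, (rank + 1) * shard
        out.append([sum(g[i] for g in per_rank_grads) for i in range(lo, hi)])
    return out


def all_gather(per_rank_shards):
    """Every rank ends up with the concatenation of all ranks' shards."""
    full = [x for shard in per_rank_shards for x in shard]
    return [list(full) for _ in per_rank_shards]


def all_reduce(per_rank_grads):
    """Every rank ends up with the full element-wise sum."""
    n = len(per_rank_grads[0])
    total = [sum(g[i] for g in per_rank_grads) for i in range(n)]
    return [list(total) for _ in per_rank_grads]


grads = [
    [1.0, 2.0, 3.0, 4.0],      # rank 0 local gradient
    [10.0, 20.0, 30.0, 40.0],  # rank 1 local gradient
]
shards = reduce_scatter(grads)   # rank 0 keeps 2 summed values, rank 1 the other 2
rebuilt = all_gather(shards)     # only needed if every rank wants the full result
```

In a sharded setup the all-gather step on gradients is typically skipped: each rank updates only the parameters matching its gradient shard.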
Annotated starter links
These files are reading material first. If you later decide to run them, treat the run as optional validation rather than the main learning path.
Starter Preview
Excerpt from code/lab-08-zero-fsdp/zero_memory_accounting.py. This preview explains the key idea; the linked starter file is the source of truth.
replicated = param_gb + grad_gb + adam_gb
zero1 = param_gb + grad_gb + adam_gb / dp
zero2 = param_gb + grad_gb / dp + adam_gb / dp
zero3 = param_gb / dp + grad_gb / dp + adam_gb / dp
# ZeRO-3/FSDP-style execution commonly needs parameter all-gather around computation.
# Gradients are commonly reduce-scattered so each rank keeps only its shard.
Key code blocks
- replicated: DDP-like baseline; every rank owns the full training state.
- zero1: Only optimizer states are partitioned.
- zero2: Optimizer states and gradients are partitioned.
- zero3: Parameters, gradients and optimizer states are partitioned.
- all-gather / reduce-scatter comments: Connect memory saving to collective communication cost.
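The four formulas in the preview can be wrapped into a small runnable calculator. The sketch below is not the starter file. The function name `zero_memory_gb` and the byte counts are assumptions chosen for illustration: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for Adam state (fp32 master weights, momentum, variance), which is the mixed-precision accounting used in the ZeRO paper. One GB is taken as 1e9 bytes.

```python
def zero_memory_gb(num_params, dp,
                   param_bytes=2, grad_bytes=2, adam_bytes=12):
    """Persistent per-rank model-state memory in GB for each sharding level.

    Byte counts are assumptions (mixed-precision Adam); adjust for your setup.
    Activations, temporary buffers and fragmentation are not included.
    """
    param_gb = num_params * param_bytes / 1e9
    grad_gb = num_params * grad_bytes / 1e9
    adam_gb = num_params * adam_bytes / 1e9
    return {
        "replicated": param_gb + grad_gb + adam_gb,
        "zero1": param_gb + grad_gb + adam_gb / dp,
        "zero2": param_gb + grad_gb / dp + adam_gb / dp,
        "zero3": param_gb / dp + grad_gb / dp + adam_gb / dp,
    }


# Example: a 7.5B-parameter model across 64 data-parallel ranks.
mem = zero_memory_gb(num_params=7.5e9, dp=64)
for name, gb in mem.items():
    print(f"{name:>10}: {gb:7.2f} GB per rank")
```

Under these assumptions the per-rank totals come out to roughly 120, 31.4, 16.6 and 1.9 GB, which match the example figures reported in the ZeRO paper for the same model size and rank count.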
How to read this code
- The formulas describe persistent per-rank model state, not temporary peaks such as gathered parameters or activations.
- A higher ZeRO stage saves more memory but adds communication and scheduling complexity.
- Checkpoint format and resume flow become part of the system design.
What this code does not mean
- “ZeRO-3 memory is exactly divided by GPU count.” Gather peaks and buffers still exist.
- “FSDP is only about speed.” Its primary motivation is fitting larger models by sharding state.
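To show why the first misconception fails, the sketch below adds a rough gather-buffer term on top of the ZeRO-3 persistent state. The model is an assumption made for illustration, not a description of any runtime. It supposes that full parameters for `units_in_flight` wrapped units of size `largest_unit_param_gb` are materialized at once. It ignores activations, per-unit gradient buffers and allocator fragmentation. The function name and its arguments are hypothetical.

```python
def zero3_peak_estimate_gb(param_gb, grad_gb, adam_gb, dp,
                           largest_unit_param_gb, units_in_flight=1):
    """Return (persistent, rough_peak) per-rank memory in GB.

    Assumed model: peak = sharded persistent state + gathered full parameters
    for the units currently in flight. Real peaks depend on runtime settings.
    """
    persistent = (param_gb + grad_gb + adam_gb) / dp
    # The rank already holds 1/dp of each unit, so count only the remainder.
    gather_buffer = units_in_flight * largest_unit_param_gb * (1 - 1 / dp)
    return persistent, persistent + gather_buffer


# Example numbers (hypothetical): 15 GB params, 15 GB grads, 90 GB Adam state,
# 64 ranks, largest wrapped unit 0.5 GB, two units gathered at once (prefetch).
persistent, peak = zero3_peak_estimate_gb(15.0, 15.0, 90.0, dp=64,
                                          largest_unit_param_gb=0.5,
                                          units_in_flight=2)
print(f"persistent: {persistent:.3f} GB, rough peak: {peak:.3f} GB")
```

Even in this simplified model the peak sits well above the "total divided by GPU count" figure, and the gap grows with the size of the largest wrapped unit.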
How to say it out loud
ZeRO/FSDP-style methods start from the DDP memory problem: every GPU keeps full parameters, gradients and optimizer states. ZeRO stages shard progressively more of this state, while PyTorch FSDP is a related sharded data-parallel runtime, not an implementation of the DeepSpeed ZeRO API. The cost is that forward/backward commonly need collectives such as all-gather and reduce-scatter, so the design trades communication for memory.
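The communication side of that trade can be sketched with the per-step volume accounting from the ZeRO paper's analysis. The function below is illustrative, and its name is an assumption. It counts elements sent per rank per training step, treats an all-reduce as a reduce-scatter plus an all-gather, and drops the (N-1)/N factor and any overlap with compute.

```python
def comm_elements_per_step(num_params, stage):
    """Approximate elements communicated per rank per step.

    stage 0 = replicated DDP baseline, 1-3 = ZeRO stages.
    Simplified accounting; ignores the (N-1)/N factor and overlap.
    """
    reduce_scatter = num_params  # aggregate gradients, keep one shard
    all_gather = num_params      # rebuild a full tensor from shards
    if stage in (0, 1, 2):
        # DDP: all-reduce on gradients = reduce-scatter + all-gather.
        # ZeRO-1/2: reduce-scatter gradients, then all-gather updated params.
        return reduce_scatter + all_gather
    if stage == 3:
        # Parameters are gathered for forward and again for backward,
        # then gradients are reduce-scattered.
        return 2 * all_gather + reduce_scatter
    raise ValueError(f"unknown stage: {stage}")


baseline = comm_elements_per_step(1_000_000, stage=0)
zero3 = comm_elements_per_step(1_000_000, stage=3)
ratio = zero3 / baseline
```

Under this accounting, stages 1 and 2 match the baseline volume and stage 3 is about 1.5 times the baseline, which is the memory-for-communication trade described above.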
Additional intuition
- DeepSpeed ZeRO docs give the clean first layer: optimizer states, gradients and parameters are progressively partitioned by ZeRO stage. Official: DeepSpeed ZeRO documentation
- PyTorch FSDP docs are useful for the communication angle because the process group is used for all-gather and reduce-scatter collectives. Official: PyTorch FSDP documentation
- Keep the intuition conservative: sharding lowers persistent state, but forward/backward can still create gather buffers and communication peaks. Paper: ZeRO
