InfraLens

A clear starting point for learning AI infrastructure.

Overview

Lab 08: ZeRO / FSDP Memory Sharding

Annotated code reading lab. Running code is optional.

Concept Goal

Read code to understand the concept

What training focuses on

Mental Model

Core mechanism

  • DDP replicates parameters, gradients and optimizer states on every rank.
  • ZeRO-1 shards optimizer states; ZeRO-2 also shards gradients; ZeRO-3 also shards parameters.
  • FSDP-style execution commonly gathers parameter shards when a wrapped unit needs them, but exact timing, prefetch and overlap depend on PyTorch configuration and version.
  • Reduce-scatter sums (or averages) gradients across ranks and leaves each rank holding only its own shard of the result.
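The reduce-scatter idea in the last bullet can be checked with a toy, single-process simulation. The function names below are hypothetical, "ranks" are just list entries, and the sketch assumes the gradient length divides evenly by the number of ranks (real frameworks pad). In PyTorch the real collectives run across processes, for example torch.distributed.reduce_scatter_tensor and torch.distributed.all_gather_into_tensor.

```python
from typing import List


def reduce_scatter(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Toy reduce-scatter: sum element-wise across ranks, rank r keeps shard r.

    Assumes every rank holds a full-length gradient and that the length
    divides evenly by the number of ranks.
    """
    world = len(per_rank_grads)
    n = len(per_rank_grads[0])
    assert n % world == 0, "toy version requires an evenly divisible length"
    shard_len = n // world
    # Element-wise sum over ranks (what an all-reduce would give every rank).
    summed = [sum(g[i] for g in per_rank_grads) for i in range(n)]
    # Each rank keeps only its contiguous slice of the summed gradient.
    return [summed[r * shard_len:(r + 1) * shard_len] for r in range(world)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Toy all-gather: concatenate every rank's shard into the full tensor."""
    return [x for s in shards for x in s]


grads = [
    [1.0, 2.0, 3.0, 4.0],      # rank 0's local gradient
    [10.0, 20.0, 30.0, 40.0],  # rank 1's local gradient
]
shards = reduce_scatter(grads)
print(shards)              # [[11.0, 22.0], [33.0, 44.0]]
print(all_gather(shards))  # [11.0, 22.0, 33.0, 44.0]
```

Running an all-gather on the reduce-scattered shards reproduces what an all-reduce would give every rank, which is why a data-parallel all-reduce is often described as reduce-scatter followed by all-gather. Frameworks usually average rather than sum, which only adds a division by the world size.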
Starter files

Annotated starter links

These files are reading material first. If you later decide to run them, treat the run as optional validation rather than the main learning path.

Annotated Code Preview

Starter Preview

Excerpt from code/lab-08-zero-fsdp/zero_memory_accounting.py. This preview explains the key idea; the linked starter file is the source of truth.

replicated = param_gb + grad_gb + adam_gb
zero1 = param_gb + grad_gb + adam_gb / dp
zero2 = param_gb + grad_gb / dp + adam_gb / dp
zero3 = param_gb / dp + grad_gb / dp + adam_gb / dp

# ZeRO-3/FSDP-style execution commonly needs parameter all-gather around computation.
# Gradients are commonly reduce-scattered so each rank keeps only its shard.
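To make the excerpt runnable on its own, here is a hypothetical helper (not the starter file's API) that fills in the four sizes from a parameter count. The bytes-per-parameter defaults are an assumption taken from the mixed-precision Adam setup analyzed in the ZeRO paper: 2 bytes for the fp16 parameter, 2 bytes for the fp16 gradient and 12 bytes of fp32 optimizer state (master weights, momentum and variance).

```python
def zero_memory_gb(num_params: float, dp: int,
                   param_bytes: int = 2, grad_bytes: int = 2,
                   adam_bytes: int = 12) -> dict:
    """Persistent per-rank training state in GB (1 GB = 1e9 bytes).

    Hypothetical helper; the byte counts assume mixed-precision Adam.
    """
    param_gb = num_params * param_bytes / 1e9
    grad_gb = num_params * grad_bytes / 1e9
    adam_gb = num_params * adam_bytes / 1e9
    return {
        "replicated": param_gb + grad_gb + adam_gb,       # DDP-like baseline
        "zero1": param_gb + grad_gb + adam_gb / dp,       # shard optimizer states
        "zero2": param_gb + grad_gb / dp + adam_gb / dp,  # also shard gradients
        "zero3": (param_gb + grad_gb + adam_gb) / dp,     # also shard parameters
    }


for stage, gb in zero_memory_gb(7.5e9, dp=64).items():
    print(f"{stage:>10}: {gb:6.2f} GB")
```

For 7.5B parameters on 64 ranks this prints about 120, 31.4, 16.6 and 1.9 GB, which lines up with the worked illustration in the ZeRO paper. ZeRO-1 already removes most of the memory here because Adam state is 12 of the 16 bytes per parameter.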
Line-by-line Explanation

Key code blocks

replicated
DDP-like baseline: every rank owns the full training state.
zero1
Only optimizer states are partitioned.
zero2
Optimizer states and gradients are partitioned.
zero3
Parameters, gradients and optimizer states are partitioned.
all-gather / reduce-scatter comments
Connect memory saving to collective communication cost.
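The communication side of the trade can be estimated with the per-rank volume argument from the ZeRO paper. This sketch assumes ring-style collectives (one all-gather or reduce-scatter over psi elements sends about psi * (dp - 1) / dp elements per rank) and assumes ZeRO-3 frees parameters after forward and re-gathers them in backward. The function names are hypothetical, and real volumes depend on bucketing, prefetch and framework settings.

```python
def collective_elems(psi: float, dp: int) -> float:
    """Elements sent per rank by one ring all-gather or reduce-scatter."""
    return psi * (dp - 1) / dp


def comm_per_step(psi: float, dp: int) -> dict:
    """Rough per-rank communication volume (in elements) for one training step."""
    one = collective_elems(psi, dp)
    return {
        # Gradient all-reduce = reduce-scatter + all-gather.
        "ddp": 2 * one,
        # Reduce-scatter gradients, then all-gather the updated parameter shards.
        "zero1_or_zero2": 2 * one,
        # All-gather params for forward, again for backward, reduce-scatter grads.
        "zero3": 3 * one,
    }


vol = comm_per_step(psi=1e9, dp=8)
print(vol["zero3"] / vol["ddp"])  # 1.5
```

Under these assumptions ZeRO-1 and ZeRO-2 move the same volume as a DDP all-reduce, while ZeRO-3 moves about 1.5x as much in exchange for also sharding parameters.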
What to Notice

How to read this code

  • The formula describes persistent state, not all temporary peaks.
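  • Higher ZeRO stages save more memory but add communication and scheduling complexity. In the ZeRO paper's analysis, stages 1 and 2 keep the same communication volume as a DDP all-reduce, while stage 3 adds parameter all-gathers and raises it to roughly 1.5x.
  • Checkpoint format and resume flow become part of the system design, because each rank saves only its shard and resuming on a different number of ranks needs the shards to be redistributed.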
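The checkpoint point is easiest to see with flat shards: if each rank saves only its slice, resuming on a different number of ranks needs a reshard step. The helpers below are hypothetical and assume 1-D shards with zero padding at the tail; real formats (for example PyTorch Distributed Checkpoint) store richer metadata about each shard.

```python
from typing import List


def shard(flat: List[float], dp: int) -> List[List[float]]:
    """Split a flat tensor into dp equal shards, zero-padding the tail."""
    shard_len = -(-len(flat) // dp)  # ceiling division
    padded = flat + [0.0] * (shard_len * dp - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(dp)]


def reshard(shards: List[List[float]], numel: int, new_dp: int) -> List[List[float]]:
    """Rebuild the full flat tensor, drop old padding, and split for new_dp ranks."""
    full = [x for s in shards for x in s][:numel]
    return shard(full, new_dp)


weights = [float(i) for i in range(10)]
saved = shard(weights, dp=4)  # 4 shards of length 3, last one has 2 padding zeros
resumed = reshard(saved, numel=len(weights), new_dp=2)
print(resumed)  # [[0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0, 9.0]]
```

The numel argument matters: without the true element count, the padding zeros from the old layout would be mistaken for real weights after resharding.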
Common Misunderstandings

What this code does not mean

  • “ZeRO-3 memory is exactly divided by GPU count.” Only the persistent state is divided; activations, gathered parameter buffers and communication buffers still add to the peak.
  • “FSDP is only about speed.” Its primary motivation is fitting larger models by sharding state.
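A rough estimate shows why ZeRO-3 memory is not simply the replicated total divided by the GPU count. This is a simplified, hypothetical model: it assumes the same 16 bytes per parameter as the mixed-precision Adam setup, counts only the fp16 parameters of the gathered units (the current unit plus any prefetched ones), and takes activations as a separate input because parameter sharding does not divide them. Unsharded gradient buffers before reduce-scatter would add more on top.

```python
def zero3_peak_gb(num_params: float, dp: int, largest_unit_params: float,
                  units_in_flight: int = 2, activations_gb: float = 0.0,
                  bytes_per_param_state: int = 16, param_bytes: int = 2) -> float:
    """Rough per-rank peak in GB (1 GB = 1e9 bytes) for ZeRO-3-style sharding.

    Hypothetical model: sharded persistent state plus gathered full
    parameters of the wrapped units in flight, plus activations.
    """
    # Sharded persistent state: params + grads + optimizer states, divided by dp.
    persistent = num_params * bytes_per_param_state / dp / 1e9
    # Temporarily gathered full (unsharded) parameters for the units in flight.
    gathered = units_in_flight * largest_unit_params * param_bytes / 1e9
    # Activations are not divided by parameter sharding.
    return persistent + gathered + activations_gb


print(zero3_peak_gb(7.5e9, dp=64, largest_unit_params=0.25e9))  # 2.875
```

With a hypothetical 0.25B-parameter largest unit and two units in flight, the 7.5B-parameter, 64-rank case goes from 1.875 GB of persistent state to about 2.9 GB before any activations, so the gather buffers alone add roughly half again.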
Interview Explanation

How to say it out loud

ZeRO/FSDP-style methods start from the DDP memory problem: every GPU keeps full parameters, gradients and optimizer states. ZeRO stages shard progressively more of this state, while PyTorch FSDP is a related sharded data-parallel runtime with its own API, not the DeepSpeed ZeRO API. The cost is that forward and backward commonly need collectives such as all-gather and reduce-scatter, so the design trades communication for memory.
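The sequence in that explanation can be walked through with a toy simulation of one ZeRO-3-style step on two simulated ranks. Everything here is an illustrative assumption rather than FSDP or DeepSpeed code: the function name is hypothetical, each rank's loss is 0.5 * ||w - target_r||^2 (so its gradient is w - target_r), plain SGD stands in for Adam, and shards divide evenly.

```python
from typing import List


def zero3_step(param_shards: List[List[float]],
               targets: List[List[float]],
               lr: float) -> List[List[float]]:
    """One toy ZeRO-3-style step; returns the updated parameter shards."""
    world = len(param_shards)
    # 1) All-gather: every rank materializes the full parameters for compute.
    full_w = [x for s in param_shards for x in s]
    # 2) Forward/backward: each rank computes a full-size local gradient.
    local_grads = [[w - t for w, t in zip(full_w, tgt)] for tgt in targets]
    # 3) Reduce-scatter with averaging: rank r keeps only its shard of the mean grad.
    n = len(full_w)
    shard_len = n // world
    mean = [sum(g[i] for g in local_grads) / world for i in range(n)]
    grad_shards = [mean[r * shard_len:(r + 1) * shard_len] for r in range(world)]
    # 4) Optimizer step on the local shard only; full params are not kept.
    return [[p - lr * g for p, g in zip(ps, gs)]
            for ps, gs in zip(param_shards, grad_shards)]


new_shards = zero3_step(
    param_shards=[[0.0, 0.0], [0.0, 0.0]],
    targets=[[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]],
    lr=0.5,
)
print(new_shards)  # [[1.0, 1.0], [1.0, 1.0]]
```

The result matches what replicated DDP would compute (SGD on the mean gradient), but between steps each rank holds only its own parameter shard, and the full parameters and gradients exist only inside the brief gathered window.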

External intuition notes

Additional intuition

  • DeepSpeed ZeRO docs give the clean first layer: optimizer states, gradients and parameters are progressively partitioned by ZeRO stage. Official: DeepSpeed ZeRO documentation
  • PyTorch FSDP docs are useful for the communication angle because the process group is used for all-gather and reduce-scatter collectives. Official: PyTorch FSDP documentation
  • Keep the intuition conservative: sharding lowers persistent state, but forward/backward can still create gather buffers and communication peaks. Paper: ZeRO
Further Reading

Official, paper and practical references