InfraLens

A clear starting point for learning AI infrastructure.

Overview

Lab 02: ZeRO / FSDP State Sharding

Annotated code reading lab. Running code is optional.

Related handbook section

ZeRO / FSDP State Sharding

This lab maps directly to the handbook section. Read the related handbook section first, then use the lab page and starter file to connect the concept to concrete variables, shapes, APIs, and interview-ready explanations.

Concept Goal

ZeRO / FSDP State Sharding

Separate which training states are sharded at each stage: optimizer states (ZeRO-1), then gradients (ZeRO-2), then parameters (ZeRO-3, which corresponds to FSDP full sharding).
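A small lookup table makes the stage-by-stage split concrete. This is a hypothetical helper (`ZERO_STAGES` and `describe` are not from the starter file). The FSDP names in the comments are the closest PyTorch `ShardingStrategy` analogues, not exact equivalents.

```python
# Hypothetical lookup table: which persistent training states each ZeRO stage
# partitions across data-parallel ranks. FSDP names in comments are the closest
# PyTorch ShardingStrategy analogues, not exact equivalents.
STATES = ("params", "grads", "optimizer")

ZERO_STAGES = {
    0: set(),                               # plain data parallel (like FSDP NO_SHARD / DDP)
    1: {"optimizer"},                       # ZeRO-1
    2: {"optimizer", "grads"},              # ZeRO-2 (closest: FSDP SHARD_GRAD_OP)
    3: {"optimizer", "grads", "params"},    # ZeRO-3 (closest: FSDP FULL_SHARD)
}


def describe(stage: int) -> dict:
    """Label each training state as 'sharded' or 'replicated' for a ZeRO stage."""
    sharded = ZERO_STAGES[stage]
    return {s: ("sharded" if s in sharded else "replicated") for s in STATES}


for stage in sorted(ZERO_STAGES):
    print(f"ZeRO-{stage}: {describe(stage)}")
```

`describe(3)` returns the same mapping as the `states` dict in the starter excerpt, which is why that excerpt describes ZeRO-3 / full sharding.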

Mental Model

Mechanism to keep in mind

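  • `params` may be all-gathered just in time: each rank stores only its shard and materializes a layer's full parameters right before that layer runs.
  • `grads` may be reduce-scattered after backward: gradients are summed across ranks, and each rank keeps only the slice matching its parameter shard.
  • `optimizer_state` stays partitioned: each rank updates only the parameters it owns.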
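The two collectives in these bullets can be simulated in a single process, with each "rank" represented by a plain Python list. This is a minimal sketch under stated assumptions (no real communication, tensor lengths divisible by the number of ranks); the function names mirror the collective names but are not `torch.distributed` calls.

```python
# Single-process simulation of the collectives behind the bullets above.
# Each "rank" is a Python list; no real communication happens. Assumes every
# tensor length divides evenly by the number of ranks.

def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank):
    """Sum element-wise across ranks; rank r keeps only chunk r of the sum."""
    world = len(per_rank)
    summed = [sum(vals) for vals in zip(*per_rank)]
    n = len(summed) // world
    return [summed[r * n:(r + 1) * n] for r in range(world)]


def all_reduce(per_rank):
    """What plain data parallelism uses: every rank gets the full sum."""
    summed = [sum(vals) for vals in zip(*per_rank)]
    return [list(summed) for _ in per_rank]


# Two ranks, four parameters, two parameters per shard.
params_shard = [[1.0, 2.0], [3.0, 4.0]]
print(all_gather(params_shard))   # -> [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]

local_grads = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
grad_shard = reduce_scatter(local_grads)
print(grad_shard)                 # -> [[11.0, 22.0], [33.0, 44.0]]

# reduce_scatter followed by all_gather equals all_reduce; ring all-reduce
# is commonly built from exactly these two phases.
assert all_gather(grad_shard) == all_reduce(local_grads)
```

The final assertion is the key intuition: sharding gradients does not add a new kind of traffic, it stops after the reduce-scatter half of an all-reduce and keeps only the local slice.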
Approximation

What the memory formula excludes

ZeRO-style state estimate

ZeRO-3 per-rank state ~= params / DP + gradients / DP + optimizer states / DP

  • DP is the data-parallel degree (the number of ranks the states are sharded across).
  • This is persistent training state only.
  • Activation memory, temporary gathers, fragmentation and workspaces can still dominate peak memory.
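Plugging numbers into the estimate shows how much each stage saves. This sketch assumes mixed-precision Adam with the ZeRO paper's per-parameter accounting: 2 bytes of fp16 params, 2 bytes of fp16 grads, and 12 bytes of optimizer state (fp32 master weights, momentum, variance). `per_rank_state_gb` is a hypothetical helper, and like the formula it counts persistent state only.

```python
# Per-rank persistent-state estimate, following the ZeRO paper's accounting for
# mixed-precision Adam. Activations, temporary all-gather buffers,
# fragmentation and workspaces are deliberately excluded.

def per_rank_state_gb(num_params: float, dp: int, stage: int,
                      param_bytes: int = 2, grad_bytes: int = 2,
                      optim_bytes: int = 12) -> float:
    p = num_params * param_bytes      # fp16/bf16 parameters
    g = num_params * grad_bytes       # fp16/bf16 gradients
    o = num_params * optim_bytes      # fp32 master params + Adam moments
    if stage >= 1:
        o /= dp                       # ZeRO-1: shard optimizer states
    if stage >= 2:
        g /= dp                       # ZeRO-2: also shard gradients
    if stage >= 3:
        p /= dp                       # ZeRO-3: also shard parameters
    return (p + g + o) / 1e9


# Example configuration from the ZeRO paper: 7.5B parameters, DP = 64.
for stage in range(4):
    print(f"stage {stage}: {per_rank_state_gb(7.5e9, 64, stage):.3f} GB")
```

For this configuration the results land at about 120, 31.4, 16.6 and 1.9 GB per rank, in line with the figures the ZeRO paper reports. Only the stage-3 line corresponds to the formula above; real peak memory will be higher once activations and gathered layers are added.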
Annotated Code Preview

Starter preview

Excerpt from code/lab-02-zero-fsdp-state-sharding/zero_fsdp_states.py. The linked starter file is the source of truth.

# ZeRO / FSDP State Sharding
# Annotated reading material. Running this file is optional.
# Source-of-truth focus: Separate which training states are sharded: optimizer, gradients, parameters.

states = {"params": "sharded", "grads": "sharded", "optimizer": "sharded"}
full_params_for_forward = "all_gather(params_shard)"
local_grads = "backward(full_params_for_forward)"
grad_shard = "reduce_scatter(local_grads)"

# What to explain while reading:
# - params may be all-gathered just in time.
# - grads may be reduce-scattered after backward.
# - optimizer_state stays partitioned.
#
# Common traps:
# - FSDP is not pipeline parallelism.
# - Sharding state does not remove communication.
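The excerpt's placeholder strings can be turned into a runnable single-process walk-through of one step. Everything here is a hypothetical toy: 2 ranks, a 4-parameter linear model with loss `0.5 * (w . x - y) ** 2`, one sample per rank, and SGD with momentum whose buffer stands in for `optimizer_state`. Real FSDP gathers per wrapped unit (typically a layer) and gathers again for backward; this sketch treats the whole model as one unit.

```python
# Toy, single-process version of the starter excerpt. Variable names follow the
# excerpt; the model, data and optimizer are hypothetical.
WORLD, LR, BETA = 2, 0.1, 0.9


def all_gather(shards):
    return [x for shard in shards for x in shard]


def reduce_scatter_mean(per_rank):
    # Average across ranks so the result matches data-parallel mean gradients.
    avg = [sum(vals) / WORLD for vals in zip(*per_rank)]
    n = len(avg) // WORLD
    return [avg[r * n:(r + 1) * n] for r in range(WORLD)]


def backward(w, x, y):
    # Gradient of 0.5 * (w . x - y) ** 2 with respect to w (forward included).
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [err * xi for xi in x]


# Persistent, sharded state: rank r owns params_shard[r] and optimizer_state[r].
params_shard = [[1.0, 2.0], [3.0, 4.0]]
optimizer_state = [[0.0, 0.0], [0.0, 0.0]]
batches = [([1.0, 0.0, 1.0, 0.0], 2.0), ([0.0, 1.0, 0.0, 1.0], 1.0)]

# 1. all_gather(params_shard): full params exist only temporarily.
full_params_for_forward = all_gather(params_shard)
# 2. backward(full_params_for_forward): each rank uses its own batch.
local_grads = [backward(full_params_for_forward, x, y) for x, y in batches]
del full_params_for_forward  # freed after use; only shards persist
# 3. reduce_scatter(local_grads): rank r keeps the averaged grad for its shard.
grad_shard = reduce_scatter_mean(local_grads)
# 4. Optimizer step touches only the local shard of params and momentum.
for r in range(WORLD):
    for i, g in enumerate(grad_shard[r]):
        optimizer_state[r][i] = BETA * optimizer_state[r][i] + g
        params_shard[r][i] -= LR * optimizer_state[r][i]

print(all_gather(params_shard))
```

Note that the full-size objects (`full_params_for_forward`, `local_grads`) are transient, while `params_shard`, `grad_shard` and `optimizer_state` are the only per-rank state that outlives the step.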
Line-by-line Explanation

What each block is doing

Setup / contract
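`states` marks params, grads and optimizer state as sharded, so `params` must be all-gathered just in time into `full_params_for_forward`.
Main transition
`backward` produces full-size `local_grads` on every rank; they are reduce-scattered into `grad_shard`, so each rank keeps only its slice.
Interview hook
`optimizer_state` stays partitioned: each rank updates only the parameter shard it owns.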
What to Notice

Reading checkpoints

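  • FSDP keeps data-parallel training semantics: each rank processes different data, and all ranks apply the same averaged gradient.
  • It saves memory by adding collectives (parameter all-gathers before compute, gradient reduce-scatters after backward).
  • Peak memory includes the temporarily all-gathered parameters, not just the persistent shards.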
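One way to check the first point is to run the same plain SGD step two ways and compare. This is a hypothetical single-process check: `ddp_step` does an all-reduce and updates the full replica, `zero3_step` reduce-scatters, updates one shard per rank, then all-gathers. Lists stand in for tensors and the parameter count is assumed to divide evenly by the number of ranks.

```python
# Hypothetical check: DDP-style and ZeRO-3-style SGD steps give the same params.

def ddp_step(params, local_grads, lr):
    world = len(local_grads)
    avg = [sum(v) / world for v in zip(*local_grads)]        # all_reduce (mean)
    return [p - lr * g for p, g in zip(params, avg)]          # full replica updated


def zero3_step(params, local_grads, lr):
    world = len(local_grads)
    n = len(params) // world
    avg = [sum(v) / world for v in zip(*local_grads)]
    new_shards = []
    for r in range(world):                                    # one iteration per rank
        p_shard = params[r * n:(r + 1) * n]                   # persistent param shard
        g_shard = avg[r * n:(r + 1) * n]                      # reduce_scatter (mean)
        new_shards.append([p - lr * g for p, g in zip(p_shard, g_shard)])
    return [x for s in new_shards for x in s]                 # all_gather for next forward


params = [0.5, -1.0, 2.0, 3.5, 0.25, -0.75]
local_grads = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
               [0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
               [1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
assert ddp_step(params, local_grads, 0.01) == zero3_step(params, local_grads, 0.01)
print("same update:", zero3_step(params, local_grads, 0.01))
```

The math is identical; what changes is where each piece of state lives and which collectives move it. In real systems, different reduction orders can introduce tiny floating-point differences, whereas this toy uses the same arithmetic in both paths.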
Common Misunderstandings

What this lab prevents

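  • FSDP is not pipeline parallelism: every rank still runs every layer; it only stores a shard of each layer's state between uses.
  • Sharding state does not remove communication: ZeRO-3 / full sharding moves roughly 1.5x the data of plain data parallelism per step.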
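A back-of-the-envelope count shows why sharding trades memory for traffic. This sketch follows the ZeRO paper's volume analysis and assumes ring-style collectives, where one all-gather or reduce-scatter over N ranks sends (N - 1) / N of the full tensor from each rank. `ring_volume` and `per_step_volume` are hypothetical helpers; latency, overlap and bucketing are ignored.

```python
# Rough per-rank send volume per training step, in parameter elements.

def ring_volume(num_elems: float, world: int) -> float:
    return num_elems * (world - 1) / world


def per_step_volume(num_params: float, world: int, stage: int) -> float:
    one_pass = ring_volume(num_params, world)
    if stage <= 2:
        # DDP / ZeRO-1 / ZeRO-2: one gradient reduce_scatter plus one all_gather
        # (the two halves of DDP's all_reduce, or gathering updated params).
        return 2 * one_pass
    # ZeRO-3: all_gather params for forward, all_gather again for backward,
    # then reduce_scatter gradients.
    return 3 * one_pass


for stage in range(4):
    vol = per_step_volume(7e9, 8, stage)
    print(f"ZeRO-{stage}: {vol / 1e9:.2f} G elements sent per rank per step")
```

Under these assumptions the ZeRO-3 to DDP ratio is 1.5 regardless of world size. Multiply by bytes per element (for example 2 for bf16) to get bytes on the wire.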
Interview Explanation

How to say it out loud

Start by separating which training states are sharded: optimizer states, gradients, and parameters. Then explain the code by naming the state being transformed, the axis or shape that matters, and the tradeoff that would appear in a real system.

External intuition notes

Additional intuition

  • Use official docs and papers for API behavior and factual claims; use blogs only to improve the mental picture.
  • If support matrices, performance behavior or backend choices are version-sensitive, check current docs before repeating them.
  • A strong interview answer names the state object, the shape or axis it changes, and the tradeoff it creates.
Further Reading

Official, paper and practical references