AI Infra Deep Dive

Distributed Training

A deep-dive branch of the AI Infra umbrella, focused on DDP, collectives, ZeRO/FSDP, tensor/pipeline/sequence/expert parallelism, and topology-aware design.

How to use this page

Concept first, code reading second

Read each mechanism in terms of the problem it solves, its synchronization pattern, state placement, communication cost, and common misunderstandings. Labs are optional annotated reading material.

Reading path

Read the related handbook section first, then use the lab page and starter file to connect the concept to concrete variables, shapes, APIs, and interview-ready explanations.

Overview: Why Distributed Training Is Needed

Distributed training extends training beyond single-device memory and compute limits by partitioning inputs, model state, and computation across multiple accelerator ranks.

What problem does this solve?

Single-GPU training runs into memory, compute-time, and input-throughput limits. Fitting large models requires dividing model states, activations, and batches across multiple nodes and chips.

Core mechanism

Distributed training splits training work using data parallel, model parallel (tensor, pipeline, sequence), and expert parallel configurations. These strategies communicate via collective operations to synchronize weights, activations, and gradients across process groups.

What to say in an interview

Explain that distributed scaling is a multi-dimensional tradeoff between memory consumption, compute efficiency, and network communication bandwidth/latency. We choose the parallel strategy based on which resource is the dominant bottleneck.

Common misunderstanding

Calling every multi-GPU method data parallel. Data parallel shards only the inputs and replicates the model. Model-parallel strategies shard the model state itself, which brings different communication patterns and topology requirements.

Data Parallel / DDP

Distributed Data Parallel (DDP) replicates the model across ranks while sharding input batches, and averages gradients at each step with all-reduce collectives.

What problem does this solve?

Training on a single device limits sample throughput. Data parallel replicates the entire model across multiple GPUs and feeds each replica a different slice of the batch, so the global batch size (and ideally throughput) grows linearly with the number of replicas.

Core mechanism

In PyTorch DDP, each process (typically one per GPU) runs forward and backward on its own shard of the batch. During the backward pass, autograd hooks registered by DDP launch asynchronous all-reduces on buckets of gradients as they become ready, overlapping communication with the remaining backward computation.

Definition: rank / world size / process group

rank is one participant in a distributed job, world_size is the number of ranks, and a process group is the communicator over which collectives such as all-reduce run.

rank 0 data shard   rank 1 data shard   rank 2 data shard
      |                   |                   |
forward/backward    forward/backward    forward/backward
      |                   |                   |
    grad                grad                grad
      +------ gradient all-reduce -----------+
                          |
                same averaged gradients
                          |
                 local optimizer step
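A minimal pure-Python sketch of one data-parallel step, assuming a toy 1-D linear model y = w * x with mean-squared-error loss and equal-sized shards. It is not PyTorch DDP: the all-reduce is simulated by averaging a list, and `shard`, `local_grad`, and `ddp_step` are hypothetical helpers. With equal shard sizes, the averaged local gradients equal the full-batch gradient, so every rank applies the same update.

```python
def shard(data, rank, world_size):
    """Strided sharding: rank r gets items r, r + world_size, r + 2 * world_size, ..."""
    return data[rank::world_size]


def local_grad(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w on one rank's batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Simulated all-reduce: every rank receives the mean of all ranks' values."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def ddp_step(w, data, world_size, lr):
    grads = [local_grad(w, shard(data, r, world_size)) for r in range(world_size)]
    synced = all_reduce_mean(grads)
    # Each rank runs its own optimizer step; identical gradients keep weights identical.
    return [w - lr * g for g in synced]


data = [(float(x), 3.0 * x) for x in range(1, 9)]  # 8 samples, true w = 3
weights = ddp_step(w=0.0, data=data, world_size=4, lr=0.01)
print(weights)  # one (identical) weight per rank
```

The point to carry into an interview: because every rank sees the same averaged gradient, no rank ever needs to broadcast weights after the optimizer step.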
What to say in an interview

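Describe gradient bucketing: DDP packs parameters into buckets (bucket_cap_mb, 25 MB by default) in roughly reverse parameter order, which approximates the order in which gradients become ready during backward. As soon as every gradient in a bucket is ready, DDP launches that bucket's all-reduce, hiding communication latency behind the remaining backpropagation.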
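A simplified sketch of the bucketing idea, assuming hypothetical parameter names and sizes in MB. It packs parameters into buckets in reverse registration order under a size cap, then simulates backward: a bucket's all-reduce "launches" as soon as its last gradient is ready. Real DDP bucketing has additional details that this sketch ignores.

```python
def assign_buckets(params, cap_mb):
    """Greedily pack (name, size_mb) params into buckets, walking in reverse registration order."""
    buckets, current, current_size = [], [], 0.0
    for name, size in reversed(params):
        if current and current_size + size > cap_mb:
            buckets.append(current)
            current, current_size = [], 0.0
        current.append(name)
        current_size += size
    if current:
        buckets.append(current)
    return buckets


def simulate_backward(buckets, grad_ready_order):
    """Return an event log: a bucket's all-reduce launches once all of its grads are ready."""
    pending = [set(b) for b in buckets]
    launched = [False] * len(buckets)
    events = []
    for name in grad_ready_order:
        events.append(f"grad ready: {name}")
        for i, names in enumerate(pending):
            names.discard(name)
            if not names and not launched[i]:
                launched[i] = True
                events.append(f"launch all-reduce: bucket {i}")
    return events


# Hypothetical 4-layer model, sizes in MB, listed in registration (forward) order.
params = [("layer1.w", 10), ("layer2.w", 10), ("layer3.w", 10), ("layer4.w", 10)]
buckets = assign_buckets(params, cap_mb=25)
# Backward produces gradients roughly in reverse forward order.
for event in simulate_backward(buckets, [name for name, _ in reversed(params)]):
    print(event)
```

Bucket 0 (layer4, layer3) launches its all-reduce while layer2 and layer1 gradients are still being computed, which is exactly the overlap the paragraph describes.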

Common misunderstanding

Assuming DDP scales indefinitely. As GPU count increases, the global batch size grows, which can hurt convergence unless the learning rate and warmup schedule are retuned. At high scale, the all-reduce can also become communication-bound, and every rank must still hold a full copy of the model, gradients, and optimizer state.

Gradient AllReduce

AllReduce synchronizes gradient state across ranks, ensuring weight updates remain identical without requiring a central coordinator.

What problem does this solve?

Data parallel replicas must update their weights using identical, averaged gradients. AllReduce sums local gradients from all ranks and leaves the result on every rank; DDP divides by the world size so each rank holds the average.

Core mechanism

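Standard collective implementations use ring or tree algorithms. In a ring all-reduce, each rank splits its tensor into N chunks and passes chunks to its neighbor around a ring. The ranks run N - 1 reduce-scatter steps (each step adds an incoming chunk to a local partial sum, so every rank ends up owning one fully reduced chunk) and then N - 1 all-gather steps (which circulate the reduced chunks until every rank holds the full result).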

Formula: ring AllReduce traffic per rank ~= 2 * (N - 1) / N * data (sent, with the same amount received)
  • N: number of ranks.
  • data: size of the tensor being reduced.
  • The factor 2 comes from the reduce-scatter and all-gather phases, each moving (N - 1) / N * data per rank.
  • This is a bandwidth-oriented estimate that ignores latency, topology, chunking, and implementation details.
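A pure-Python simulation of the ring algorithm, assuming each rank's tensor is a list whose length is divisible by N. It runs N - 1 reduce-scatter steps and N - 1 all-gather steps and counts the elements each rank sends, so the result can be checked against both a direct sum and the 2 * (N - 1) / N * data formula. It models the data movement only, not NCCL's actual implementation.

```python
def ring_all_reduce(tensors):
    """Simulate a sum all-reduce over N ranks arranged in a ring.

    Returns (buffers, elements_sent_per_rank); every buffer ends up holding the sum.
    """
    n = len(tensors)
    length = len(tensors[0])
    assert length % n == 0, "sketch assumes the tensor splits evenly into N chunks"
    c = length // n
    bufs = [list(t) for t in tensors]  # each rank's local buffer
    sent = [0] * n

    def chunk(rank, idx):
        return bufs[rank][idx * c:(idx + 1) * c]  # slice copy

    # Phase 1, reduce-scatter: rank r sends chunk (r - step) % n to rank r + 1, which
    # adds it in. Afterwards rank r owns the full sum of chunk (r + 1) % n.
    for step in range(n - 1):
        # All ranks send simultaneously, so build every message before any buffer changes.
        msgs = [((r - step) % n, chunk(r, (r - step) % n)) for r in range(n)]
        for r, (idx, payload) in enumerate(msgs):
            dst = (r + 1) % n
            for i, v in enumerate(payload):
                bufs[dst][idx * c + i] += v
            sent[r] += len(payload)

    # Phase 2, all-gather: each rank forwards the newest complete chunk it holds,
    # and the receiver overwrites its stale copy.
    for step in range(n - 1):
        msgs = [((r + 1 - step) % n, chunk(r, (r + 1 - step) % n)) for r in range(n)]
        for r, (idx, payload) in enumerate(msgs):
            bufs[(r + 1) % n][idx * c:(idx + 1) * c] = payload
            sent[r] += len(payload)

    return bufs, sent


ranks = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
result, sent = ring_all_reduce(ranks)
print(result[0])  # [1111, 2222, 3333, 4444], identical on every rank
print(sent)       # [6, 6, 6, 6] = 2 * (4 - 1) / 4 * 4 elements per rank
```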
What to say in an interview

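State the bandwidth cost of ring all-reduce: 2 * (N - 1) / N * data per rank. As N grows this approaches 2 * data, so the per-rank bandwidth cost is nearly independent of cluster size. The caveat is latency: a ring needs 2(N - 1) sequential steps, so for small messages or very large N the latency term dominates, which is one reason NCCL also provides tree-based algorithms.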
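To make the latency caveat concrete, here is a back-of-the-envelope alpha-beta (latency-bandwidth) estimate for ring all-reduce. The model is a standard simplification and is an assumption here: each of the 2(N - 1) steps pays a fixed latency alpha, and each rank moves 2(N - 1) / N * data bytes at bandwidth bw. The alpha and bw values are illustrative placeholders, not measurements of any real network.

```python
def ring_allreduce_time(n, data_bytes, alpha_s, bw_bytes_per_s):
    """Return (latency_term, bandwidth_term) in seconds for a ring all-reduce."""
    steps = 2 * (n - 1)
    latency = steps * alpha_s                                    # grows linearly with N
    bandwidth = 2 * (n - 1) / n * data_bytes / bw_bytes_per_s    # saturates near 2 * data / bw
    return latency, bandwidth


# Illustrative assumptions: 10 us per step, 25 GB/s per-rank bandwidth.
for n in (8, 64, 512, 4096):
    for size in (1e9, 1e6):  # a 1 GB gradient vs. a 1 MB bucket
        lat, bw = ring_allreduce_time(n, size, alpha_s=10e-6, bw_bytes_per_s=25e9)
        print(f"N={n:5d} size={size:.0e} B latency={lat * 1e3:8.2f} ms bandwidth={bw * 1e3:8.3f} ms")
```

Large gradients stay bandwidth-dominated at every N, while small buckets become latency-dominated as the ring grows.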

Common misunderstanding

Thinking all-reduce sends all gradients to a central master node that broadcasts the result. That master-worker (parameter-server) pattern makes the master's network link a severe bottleneck. Production setups use peer-to-peer ring or tree collectives, typically through NCCL on GPUs.

ZeRO / FSDP

ZeRO and FSDP eliminate redundant memory by sharding parameters, gradients, and optimizer states across the data parallel group.

What problem does this solve?

DDP replicates weights, gradients, and optimizer states on every GPU, which limits the maximum model size to what fits on a single device. ZeRO shards these states across the data-parallel ranks to free up memory.

Core mechanism

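ZeRO shards states in stages: ZeRO-1 shards optimizer states, ZeRO-2 shards optimizer states and gradients, and ZeRO-3 shards optimizer states, gradients, and parameters. PyTorch FSDP's full-sharding mode corresponds to ZeRO-3. FSDP all-gathers a unit's parameter shards just before that unit runs in the forward pass (and again in the backward pass) and frees the full parameters afterward, reducing persistent weight memory.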

Formula: ZeRO per-GPU training-state memory approximation
replicated ~= params + gradients + optimizer states
ZeRO-1 ~= params + gradients + optimizer states / DP
ZeRO-2 ~= params + gradients / DP + optimizer states / DP
ZeRO-3 ~= params / DP + gradients / DP + optimizer states / DP
  • DP: data-parallel degree.
  • This estimates persistent training states, not activations, temporary all-gather buffers, fragmentation, or workspaces.
  • ZeRO-3/FSDP can still hit runtime memory peaks while parameters are materialized.
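A worked version of these formulas, assuming mixed-precision Adam with the byte counts used in the ZeRO paper: 2 bytes per parameter for fp16/bf16 weights, 2 for gradients, and 12 for optimizer state (fp32 master weights, momentum, and variance). Like the formulas, it counts persistent training state only.

```python
PARAM_BYTES, GRAD_BYTES, OPT_BYTES = 2, 2, 12  # assumed mixed-precision Adam byte counts


def zero_state_gb(num_params, dp, stage):
    """Per-GPU persistent training-state memory in GB (1e9 bytes) for ZeRO stage 0-3."""
    p = num_params * PARAM_BYTES
    g = num_params * GRAD_BYTES
    o = num_params * OPT_BYTES
    if stage >= 1:
        o /= dp  # ZeRO-1: shard optimizer states
    if stage >= 2:
        g /= dp  # ZeRO-2: also shard gradients
    if stage >= 3:
        p /= dp  # ZeRO-3 / full-shard FSDP: also shard parameters
    return (p + g + o) / 1e9


for stage in range(4):
    label = "replicated" if stage == 0 else f"ZeRO-{stage}"
    print(f"{label}: {zero_state_gb(7.5e9, dp=64, stage=stage):.2f} GB")
```

For a 7.5B-parameter model with DP = 64 this gives about 120, 31.4, 16.6, and 1.9 GB per GPU, which shows why optimizer-state sharding alone already recovers most of the memory.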
Replicated training:
GPU0: params + grads + optimizer
GPU1: params + grads + optimizer

ZeRO/FSDP-style sharding:
GPU0: state shard 0
GPU1: state shard 1

forward may all-gather parameters
backward may re-gather parameters and reduce-scatter gradients

FSDP-style execution commonly gathers parameter shards when a module needs them and reduces/shards gradients during backward. Exact all-gather timing, prefetch behavior and overlap depend on PyTorch FSDP configuration and version.
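A pure-Python sketch (not the FSDP API) of why sharded updates reproduce the DDP result, assuming plain SGD and a flat parameter vector split evenly across ranks. Gradients are reduce-scattered so each rank holds the averaged gradient for its own shard only, each rank updates only that shard, and an all-gather rebuilds the full parameters when a module next needs them. All helper names are hypothetical.

```python
def split(vec, n):
    c = len(vec) // n
    return [vec[i * c:(i + 1) * c] for i in range(n)]


def reduce_scatter_mean(local_grads):
    """Rank r receives the mean over all ranks of shard r of the gradient."""
    n = len(local_grads)
    shards = [split(g, n) for g in local_grads]  # shards[src_rank][shard_idx]
    return [[sum(shards[src][r][i] for src in range(n)) / n for i in range(len(shards[0][r]))]
            for r in range(n)]


def all_gather(shards):
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]  # every rank ends with the full vector


def sharded_step(params, local_grads, lr):
    n = len(local_grads)
    param_shards = split(params, n)                  # persistent state: one shard per rank
    grad_shards = reduce_scatter_mean(local_grads)
    new_shards = [[p - lr * g for p, g in zip(ps, gs)]
                  for ps, gs in zip(param_shards, grad_shards)]
    return all_gather(new_shards)                    # materialize full params for the next forward


def replicated_step(params, local_grads, lr):
    n = len(local_grads)
    avg = [sum(g[i] for g in local_grads) / n for i in range(len(params))]  # DDP-style all-reduce
    return [p - lr * g for p, g in zip(params, avg)]


params = [1.0, 2.0, 3.0, 4.0]
grads = [[0.5, 1.0, 1.5, 2.0], [1.5, 1.0, 0.5, 0.0]]  # two ranks, different data
print(sharded_step(params, grads, lr=0.1)[0])
print(replicated_step(params, grads, lr=0.1))
```

Both paths produce the same weights; the sharded path simply never keeps the full optimizer update on any single rank.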

What to say in an interview

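Detail the memory math of ZeRO stages. Discuss how FSDP trades communication for memory: it replaces DDP's single gradient all-reduce with parameter all-gathers in forward and backward plus a gradient reduce-scatter. Since an all-reduce costs about the same as a reduce-scatter plus an all-gather, ZeRO-3 moves roughly 1.5x the data of DDP per step, while ZeRO-1 and ZeRO-2 keep DDP's communication volume.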

Common misunderstanding

Assuming ZeRO-3 has no peak memory issues. While persistent weights are sharded, a layer's parameters must be materialized in full on the GPU during its forward and backward computation, creating temporary memory peaks.

Tensor Parallel

Tensor Parallel partitions matrix weights within a single layer across ranks, splitting wide matrix multiplications to distribute the computational load.

What problem does this solve?

When a model's layers are too large for one GPU's memory, or too slow to compute on one device, we must split the matrix multiplications within each layer across multiple GPUs.

Core mechanism

In Megatron-LM style TP, attention and MLP blocks are split. The first linear layer of each block is partitioned column-wise and the second row-wise, so the column-parallel output feeds the row-parallel layer directly, with no communication in between. Each block then needs only one all-reduce in the forward pass and one in the backward pass, i.e., two of each per transformer layer.

What to say in an interview

Be ready to draw the column-row split diagram. Explain that column-parallel projection requires no immediate communication, but the subsequent row-parallel projection must run an all-reduce to sum the split outputs.
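A small numeric check of the column-then-row split in pure Python, assuming hypothetical 2x4 and 4x2 weight matrices and an elementwise nonlinearity between them (ReLU here for simplicity). Rank r holds a column block of A and the matching row block of B; the only communication is the final sum, which stands in for the all-reduce.

```python
def matmul(x, w):
    return [[sum(x[i][k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]


def relu(m):
    return [[max(0.0, v) for v in row] for row in m]


def cols(w, start, end):
    return [row[start:end] for row in w]


def mat_sum(mats):
    return [[sum(m[i][j] for m in mats) for j in range(len(mats[0][0]))]
            for i in range(len(mats[0]))]


X = [[1.0, -2.0], [0.5, 3.0]]                           # activations, replicated on every TP rank
A = [[1.0, -1.0, 2.0, 0.5], [0.5, 1.0, -1.0, 2.0]]      # first linear: split by columns
B = [[1.0, 0.0], [-1.0, 2.0], [0.5, 1.0], [2.0, -1.0]]  # second linear: split by rows

tp = 2
block = len(A[0]) // tp
partials = []
for rank in range(tp):
    a_shard = cols(A, rank * block, (rank + 1) * block)  # column-parallel: no communication needed
    b_shard = B[rank * block:(rank + 1) * block]         # row-parallel: the matching rows
    h = relu(matmul(X, a_shard))                         # elementwise op works on local columns
    partials.append(matmul(h, b_shard))                  # partial sum of the final output
y_tp = mat_sum(partials)                                 # stands in for the all-reduce
y_ref = matmul(relu(matmul(X, A)), B)
print(y_tp, y_ref)
```

The split works because the nonlinearity is elementwise: each rank can apply it to its own columns without seeing the others.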

Common misunderstanding

Thinking TP scales well across slow network links. Because TP issues collectives inside every layer, on the critical path of both forward and backward, it is highly sensitive to latency and bandwidth. In practice it is usually kept within a single NVLink domain, typically the 8 GPUs of one node.

Pipeline Parallel

Pipeline Parallel partitions model layers into sequential stages, often placed on different nodes, and splits each batch into micro-batches so stages can work concurrently; schedules such as 1F1B control activation memory.

What problem does this solve?

When a model does not fit within a single node's NVLink domain, even with tensor parallelism, we partition its layers sequentially across multiple nodes.

Core mechanism

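PP divides the model layers into sequential stages and splits each batch into micro-batches. While the pipeline fills and drains, some stages sit idle; this idle time is the pipeline "bubble". GPipe runs all forward micro-batches and then all backward micro-batches, so each stage holds activations for every micro-batch. 1F1B (One Forward, One Backward) alternates forward and backward steps after a warmup, which caps in-flight activations by pipeline depth. Non-interleaved 1F1B has the same bubble size as GPipe; the bubble shrinks with more micro-batches or with interleaved schedules that assign several smaller stages to each device.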

Schedule | Ordering | Activation and bubble implication | Use in an explanation
GPipe | Runs all forward micro-batches, then all backward micro-batches. | Simple, but stores activations for every micro-batch; the bubble shrinks only as the number of micro-batches grows. | Use it to introduce pipeline fill and drain.
1F1B | Alternates one forward and one backward after warmup. | Bounds resident activations by pipeline depth; without interleaving, the bubble is the same size as GPipe's. | Use it when explaining practical PP memory control.
What to say in an interview

Contrast GPipe and 1F1B: with p stages and m micro-batches, GPipe keeps activations for all m micro-batches on each stage, whereas 1F1B keeps at most about p in flight, so its activation memory no longer grows with m.
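A sketch that generates per-stage operation orders for GPipe and non-interleaved 1F1B (p stages, m micro-batches) and counts each stage's peak number of micro-batches whose activations are held between forward and backward. It also evaluates the idealized bubble fraction (p - 1) / (m + p - 1), assuming fixed forward and backward times per stage and ignoring communication; both schedules share this bubble, so the difference between them is activation memory. The helper names are hypothetical.

```python
def gpipe_order(m):
    """All forwards, then all backwards."""
    return [("F", i) for i in range(m)] + [("B", i) for i in range(m)]


def one_f_one_b_order(p, m, stage):
    """Warmup forwards, then alternate one forward / one backward, then drain backwards."""
    warmup = min(p - stage - 1, m)
    order = [("F", i) for i in range(warmup)]
    next_f, next_b = warmup, 0
    while next_f < m:  # steady state
        order.append(("F", next_f))
        next_f += 1
        order.append(("B", next_b))
        next_b += 1
    while next_b < m:  # cooldown
        order.append(("B", next_b))
        next_b += 1
    return order


def peak_in_flight(order):
    """Max number of micro-batches whose forward finished but backward has not."""
    live = peak = 0
    for op, _ in order:
        live += 1 if op == "F" else -1
        peak = max(peak, live)
    return peak


def bubble_fraction(p, m):
    return (p - 1) / (m + p - 1)


p, m = 4, 8
for s in range(p):
    print(f"stage {s}: GPipe peak={peak_in_flight(gpipe_order(m))}, "
          f"1F1B peak={peak_in_flight(one_f_one_b_order(p, m, s))}")
print(f"bubble fraction (both schedules, idealized): {bubble_fraction(p, m):.3f}")
```

Raising m lowers the bubble fraction for both schedules, but only 1F1B does so without raising per-stage activation memory.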

Common misunderstanding

Assuming PP is communication-bound. PP only sends boundary activations (forward) and their gradients (backward) point-to-point between adjacent stages. This is far less traffic than TP, which makes PP well suited to cross-node partitioning; its main costs are the bubble and load imbalance between stages.

Sequence / Context Parallel

Sequence and Context Parallel shard the token sequence dimension across ranks, mitigating activation memory peaks during long-context training.

What problem does this solve?

Long-context training generates large activation tensors in attention layers, causing GPU out-of-memory (OOM) failures even when weights are sharded.

Core mechanism

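Sequence and context parallelism partition the sequence dimension across GPUs. Megatron-style sequence parallelism shards the activations of regions that tensor parallelism leaves replicated (such as LayerNorm and dropout) along the sequence dimension inside the TP group, replacing each TP all-reduce with a reduce-scatter plus an all-gather. Context parallelism also shards the sequence through attention itself, using all-gathers of keys and values, all-to-all collectives (Ulysses), or ring passes of key/value blocks (Ring Attention).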

What to say in an interview

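Explain that Ulysses-style sequence parallelism uses an all-to-all to switch from sequence-sharded to head-sharded activations before attention (and another all-to-all to switch back afterward), while Ring Attention keeps each rank's query block local and passes key/value blocks around a ring of GPUs, combining partial results blockwise with an online softmax. With a causal mask, contiguous chunks give ranks unequal work, so implementations often reorder sequence chunks to balance load.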
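A pure-Python sketch of the Ulysses-style layout change, assuming P ranks, a sequence of S tokens, and H attention heads with both S and H divisible by P. Each "activation" is just a (token, head) label so the redistribution is easy to check; `all_to_all` and `ulysses_seq_to_head` are hypothetical helpers, not a library API.

```python
def all_to_all(send):
    """send[src][dst] is the block rank src sends to rank dst; returns recv[dst][src]."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]


def ulysses_seq_to_head(P, S, H):
    s_blk, h_blk = S // P, H // P
    # Sequence-sharded layout: rank r owns tokens [r*s_blk, (r+1)*s_blk) for every head.
    local = [[(t, h) for t in range(r * s_blk, (r + 1) * s_blk) for h in range(H)]
             for r in range(P)]
    # Each rank splits its tokens by head group and sends head group d to rank d.
    send = [[[(t, h) for (t, h) in local[r] if h // h_blk == d] for d in range(P)]
            for r in range(P)]
    recv = all_to_all(send)
    # After the exchange, rank d owns every token, but only heads [d*h_blk, (d+1)*h_blk).
    return [sorted(x for block in recv[d] for x in block) for d in range(P)]


layout = ulysses_seq_to_head(P=2, S=4, H=4)
print(layout[0])  # rank 0: all 4 tokens, heads 0 and 1
```

After the exchange each rank can run ordinary full-sequence attention for its heads; the reverse all-to-all restores the sequence-sharded layout.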

Common misunderstanding

Thinking SP and TP are mutually exclusive. Megatron-style sequence parallelism is designed to run inside the TP group, and context parallelism is commonly combined with TP as an additional parallel dimension, scaling context length and model width together within fast interconnect domains.

Expert Parallel / MoE

Expert Parallel shards Mixture-of-Experts layers across ranks, utilizing all-to-all collectives to route tokens to matching expert locations.

What problem does this solve?

Mixture-of-Experts (MoE) grows the parameter count without a proportional increase in per-token compute by routing each token to a small subset of expert layers.

Core mechanism

A gating network routes tokens to top-k experts. Ranks run all-to-all communication to send tokens to the GPUs holding the chosen experts, execute the expert MLPs in parallel, and run all-to-all to return the processed tokens.
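A sketch of the dispatch bookkeeping from one source rank's point of view, assuming hypothetical router scores, E experts placed contiguously (expert e lives on rank e // experts_per_rank), and top-k selection. It groups (token, expert) pairs by destination rank, which is what the first all-to-all sends; the second all-to-all reverses this mapping.

```python
def top_k(scores, k):
    """Indices of the k largest scores (ties broken by lower index)."""
    return sorted(range(len(scores)), key=lambda e: (-scores[e], e))[:k]


def dispatch(router_scores, k, experts_per_rank, world_size):
    """Group (token, expert) assignments by the rank that hosts each expert."""
    send = {r: [] for r in range(world_size)}
    for token, scores in enumerate(router_scores):
        for expert in top_k(scores, k):
            send[expert // experts_per_rank].append((token, expert))
    return send


# Hypothetical router scores: 4 tokens x 4 experts, 2 experts per rank, 2 ranks.
scores = [
    [0.9, 0.1, 0.5, 0.2],
    [0.2, 0.8, 0.1, 0.7],
    [0.6, 0.3, 0.4, 0.9],
    [0.1, 0.2, 0.3, 0.4],
]
send = dispatch(scores, k=2, experts_per_rank=2, world_size=2)
print({rank: len(items) for rank, items in send.items()})  # all-to-all send counts per rank
```

Even this tiny example sends 3 token copies to one rank and 5 to the other, which is the imbalance the next paragraph is about.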

What to say in an interview

Discuss MoE bottlenecks: load balancing is critical because if most tokens route to the same expert, the GPU hosting it becomes a straggler while other GPUs sit idle. Mention capacity factors, which cap the number of tokens each expert accepts per batch, and auxiliary load-balancing losses that encourage the router to spread tokens evenly.
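A sketch of both load-balancing controls, assuming top-1 routing for simplicity. Capacity is computed as ceil(capacity_factor * tokens / num_experts); implementations differ in rounding and in what happens to overflow (this sketch just records dropped tokens; in Switch Transformer they skip the expert and continue through the residual connection). The auxiliary loss follows the Switch Transformer form num_experts * sum_i f_i * P_i, where f_i is the fraction of tokens routed to expert i and P_i is the mean router probability for expert i; it equals 1.0 under perfectly uniform routing.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_top1_with_capacity(logits, capacity_factor):
    """Return (kept tokens per expert, dropped tokens, capacity) under a fixed capacity."""
    num_tokens, num_experts = len(logits), len(logits[0])
    capacity = math.ceil(capacity_factor * num_tokens / num_experts)
    kept = [[] for _ in range(num_experts)]
    dropped = []
    for token, row in enumerate(logits):
        expert = max(range(num_experts), key=lambda e: row[e])
        if len(kept[expert]) < capacity:
            kept[expert].append(token)
        else:
            dropped.append(token)  # overflow: this token skips the expert
    return kept, dropped, capacity


def switch_aux_loss(logits):
    """num_experts * sum_i f_i * P_i for top-1 routing."""
    num_tokens, num_experts = len(logits), len(logits[0])
    probs = [softmax(row) for row in logits]
    counts = [0] * num_experts
    for row in probs:
        counts[max(range(num_experts), key=lambda e: row[e])] += 1
    f = [c / num_tokens for c in counts]
    p_mean = [sum(row[e] for row in probs) / num_tokens for e in range(num_experts)]
    return num_experts * sum(fi * pi for fi, pi in zip(f, p_mean))


skewed = [[3.0, 0.0, 0.0, 0.0]] * 8  # every token prefers expert 0
balanced = [[3.0 if e == t % 4 else 0.0 for e in range(4)] for t in range(8)]
kept, dropped, cap = route_top1_with_capacity(skewed, capacity_factor=1.25)
print(cap, len(kept[0]), len(dropped))
print(switch_aux_loss(skewed), switch_aux_loss(balanced))
```

The skewed router overflows expert 0 and pays a much higher auxiliary loss than the balanced one, which is the gradient signal that pushes routing toward balance.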

Common misunderstanding

Assuming expert parallelism is compute-heavy. Each token visits only k experts, so FLOPs per token stay modest relative to the parameter count. However, the two all-to-all exchanges per MoE layer move every routed token across the network, so EP is often communication-bound, especially when it spans nodes.

3D Parallelism

3D Parallelism integrates TP, PP, and DP/ZeRO axes to scale models to thousands of GPUs across diverse cluster topologies.

What problem does this solve?

Training frontier LLMs with trillions of parameters requires combining multiple parallel axes to optimize memory and compute scaling across thousands of GPUs.

Core mechanism

3D Parallelism integrates Data Parallel (DP/ZeRO), Tensor Parallel (TP), and Pipeline Parallel (PP). TP operates within a single GPU node (NVLink), PP partitions layers across nodes, and DP/ZeRO scales throughput across the entire cluster.

Axis | What it partitions | Preferred placement | Main cost
TP | Matrix computation inside layers. | Fast intra-node GPU links such as NVLink. | Frequent latency-sensitive collectives.
PP | Sequential groups of layers. | Across balanced stages and acceptable inter-node links. | Pipeline bubbles and activation transfer.
DP / ZeRO | Samples or sharded training state across replicas. | Remaining replica groups across the cluster. | Gradient or state synchronization and checkpointing.
What to say in an interview

Describe how to map the 3D grid to hardware topology: TP gets the fastest links (intra-node NVLink), PP gets intermediate connections (inter-node InfiniBand), and DP/ZeRO spans the remaining nodes to parallelize batch processing.
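A sketch of one possible rank-to-coordinate convention, assuming 8 GPUs per node, global ranks numbered node by node, and an ordering where TP varies fastest, then PP, then DP. Frameworks differ in their default ordering, so treat this as illustrative; the point is that with tp = 8 every TP group lands inside a single node, while PP and DP groups cross nodes.

```python
GPUS_PER_NODE = 8  # assumption


def coords(rank, tp, pp, dp):
    """Map a global rank to (dp, pp, tp) indices with TP varying fastest."""
    assert 0 <= rank < tp * pp * dp
    return rank // (tp * pp), (rank // tp) % pp, rank % tp


def groups(axis, tp, pp, dp):
    """Collect ranks that share every coordinate except `axis` ("tp", "pp", or "dp")."""
    idx = {"dp": 0, "pp": 1, "tp": 2}[axis]
    out = {}
    for rank in range(tp * pp * dp):
        c = coords(rank, tp, pp, dp)
        key = c[:idx] + c[idx + 1:]
        out.setdefault(key, []).append(rank)
    return list(out.values())


def crosses_nodes(group):
    return len({r // GPUS_PER_NODE for r in group}) > 1


tp, pp, dp = 8, 4, 2  # 64 GPUs = 8 nodes
for axis in ("tp", "pp", "dp"):
    gs = groups(axis, tp, pp, dp)
    print(axis, gs[0], "crosses nodes:", any(crosses_nodes(g) for g in gs))
```

Swapping the ordering so that PP or DP varies fastest would put TP groups across nodes, which is the misconfiguration described next.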

Common misunderstanding

Thinking you can assign 3D parallel ranks arbitrarily. A poor layout, such as spreading a TP group across nodes connected by slow Ethernet, puts latency-sensitive per-layer collectives on the slowest links and can sharply reduce GPU utilization.

Topology and Communication

Topology-aware placement assigns ranks to process groups so that the most frequent, latency-sensitive communication runs over the fastest physical links.

What problem does this solve?

Collectives must adapt to physical hardware layout. A mismatch between communication patterns and network topology causes severe congestion and slows down training.

Core mechanism

Clusters have hierarchical connections: intra-node links (e.g., NVLink at 900 GB/s total bandwidth per H100 GPU), inter-node networks (InfiniBand/RoCE, e.g., about 50 GB/s per 400 Gb/s NIC), and inter-rack switching. Collective libraries like NCCL detect this topology and build rings or trees that maximize use of the available link bandwidth.
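A rough arithmetic sketch of why placement matters, reusing the bandwidth-only ring all-reduce formula 2 * (N - 1) / N * data and treating the example link speeds (900 GB/s NVLink-class, 50 GB/s InfiniBand/RoCE-class) naively as usable per-rank bandwidth. The 14 GB gradient size is a hypothetical example, and latency, protocol overhead, and hierarchical algorithms are ignored.

```python
def ring_allreduce_seconds(n, data_bytes, bw_bytes_per_s):
    """Bandwidth-only ring estimate: 2 * (N - 1) / N * data / bandwidth."""
    return 2 * (n - 1) / n * data_bytes / bw_bytes_per_s


grad_bytes = 14e9  # hypothetical: 7B parameters with 2-byte gradients
links = {"NVLink-class (900 GB/s)": 900e9, "IB/RoCE-class (50 GB/s)": 50e9}
for name, bw in links.items():
    print(f"{name}: {ring_allreduce_seconds(8, grad_bytes, bw) * 1e3:.1f} ms")
```

Under these assumptions the same 8-rank all-reduce is 18x slower over the inter-node link, which is why the most frequent collectives belong on NVLink.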

What to say in an interview

Discuss topology-aware rank placement. Ranks that communicate frequently (TP, Ulysses SP) must be placed on the same node. Ranks with intermediate traffic (PP, MoE EP) can span nodes, while DP covers the rest.

Common misunderstanding

Assuming standard TCP/IP Ethernet is sufficient for large-scale distributed training. Ordinary TCP has high CPU overhead and latency. High-performance clusters use RDMA (Remote Direct Memory Access) over InfiniBand or RoCE to bypass the kernel network stack, often with GPUDirect RDMA so NICs read and write GPU memory directly.

Interview Questions

Production interview scenarios evaluate failure diagnoses, performance modeling, and multi-node trace analyses.

What problem does this solve?

Translating distributed concepts into concrete production diagnoses. Interviews test whether you can pinpoint why a model OOMs or why scaling efficiency degrades.

Core mechanism

Analyze system state using a memory ledger, trace collective traffic patterns, and use roofline models to identify whether a training job is compute-bound, HBM-bandwidth-bound, or network-communication-bound.
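A simplified step-time model to support that analysis, assuming compute, HBM traffic, and network traffic overlap perfectly so the slowest resource sets the step time. All inputs below are hypothetical; in practice they come from a FLOP count, a memory ledger, and profiler traces.

```python
def classify_step(flops, hbm_bytes, net_bytes, peak_flops, hbm_bw, net_bw):
    """Return the dominant resource and per-resource time (seconds) under perfect overlap."""
    times = {
        "compute": flops / peak_flops,
        "hbm-bandwidth": hbm_bytes / hbm_bw,
        "network": net_bytes / net_bw,
    }
    bound = max(times, key=times.get)
    return bound, times


# Hypothetical per-GPU step: 1e15 FLOPs, 2e12 HBM bytes, 3e10 network bytes.
bound, times = classify_step(1e15, 2e12, 3e10, peak_flops=1e15, hbm_bw=3e12, net_bw=25e9)
print(bound, {k: round(v, 3) for k, v in times.items()})
```

With these inputs the network term (1.2 s) exceeds compute (1.0 s), so faster GPUs would not help; better overlap, larger buckets, or a different parallel layout would.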

What to say in an interview

Always structure your answers by first naming the state (weights, grads, activations), then identifying the network topology/collective involved, and finally quantifying the performance tradeoff or metric win.

Common misunderstanding

Focusing purely on algorithmic details. Interviewers want to see operational understanding: how to configure bucket sizes, how to diagnose link failures, and how to read Nsight profile traces.

Interview Practice

Use these representative prompts to rehearse mechanisms and tradeoffs. The full Q&A lives in the interview section so this handbook stays concept-first.

  • Why does DDP all-reduce gradients instead of parameters?
  • What changes from ZeRO-1 to ZeRO-3?
  • Why should tensor parallel usually stay inside a fast interconnect domain?
  • How do you diagnose communication-bound distributed training?
Runtime Extensions

Where distributed state appears next

Parallel layouts become operational systems when they synchronize online policies or serve requests under failure.

RL Infrastructure

Apply sharded learner state and distributed checkpoints to online rollout/update loops.

Systems Runtime

Connect collectives and topology to timeouts, communicators, admission, and failure diagnosis.

Annotated Labs

Code reading curriculum

Each lab includes a starter file, key snippets, line-by-line explanation, common misunderstandings, and interview framing.

# | Lab
01 | DDP Gradient Buckets
02 | ZeRO / FSDP State Sharding
03 | Tensor Parallel Linear
04 | Pipeline Schedule Reading
05 | Sequence / Context Parallel
06 | Expert Parallel Routing
07 | 3D Parallelism Plan
08 | Topology Communication
09 | Interview Scenarios

References

Official sources and high-quality intuition notes

Use official sources for factual checks and blogs only for supporting intuition.
