InfraLens

A clear starting point for learning AI infrastructure.

Overview

Lab 03: Tensor Parallel Linear

Annotated code reading lab. Running code is optional.

Related handbook section

Tensor Parallel Linear

This lab maps directly to the handbook section. Read the related handbook section first, then use the lab page and starter file to connect the concept to concrete variables, shapes, APIs, and interview-ready explanations.

Concept Goal

Tensor Parallel Linear

Read tensor parallelism as splitting one layer's matrix multiply across ranks.
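Here is the same idea in matrix form, sketched for tp = 2. X is the layer input, W is the weight, and Y = XW is the output. The block labels W_1, W_2, X_1 and X_2 are illustrative and are not names from the starter file.

```latex
% Column-parallel: split W by output columns; each rank computes one slice of Y.
W = \begin{bmatrix} W_1 & W_2 \end{bmatrix}
\quad\Rightarrow\quad
Y = XW = \begin{bmatrix} XW_1 & XW_2 \end{bmatrix}

% Row-parallel: split W by input rows and X by the matching columns;
% each rank computes a full-width partial sum.
W = \begin{bmatrix} W_1 \\ W_2 \end{bmatrix},\quad
X = \begin{bmatrix} X_1 & X_2 \end{bmatrix}
\quad\Rightarrow\quad
Y = XW = X_1 W_1 + X_2 W_2
```

The column split yields slices that must be concatenated. The row split yields partial sums that must be added, and that addition is what an all-reduce does.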

Mental Model

Mechanism to keep in mind

  • `column_shard` splits output features.
  • `row_shard` splits input features.
  • `collective` merges partial results when needed.
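The three bullets above can be checked numerically. This minimal sketch simulates tp ranks inside one process using plain Python lists. The helpers `matmul`, `column_parallel` and `row_parallel` are hypothetical illustrations and are not part of the starter file or of any library. Concatenation stands in for an all-gather, and element-wise summation stands in for an all-reduce.

```python
# Minimal sketch: simulate tp ranks in one process with plain Python lists.
# matmul, column_parallel and row_parallel are hypothetical helpers for illustration.

def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x n) matrix, both stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def column_parallel(x, w, tp):
    """Each rank holds a block of W's output columns (assumes even division).
    Concatenating the per-rank output slices plays the role of an all-gather."""
    n = len(w[0]) // tp
    slices = [matmul(x, [row[r * n:(r + 1) * n] for row in w]) for r in range(tp)]
    return [sum((s[i] for s in slices), []) for i in range(len(x))]

def row_parallel(x, w, tp):
    """Each rank holds a block of W's input rows plus the matching input columns.
    Summing the per-rank partial outputs plays the role of an all-reduce."""
    k = len(w) // tp
    partials = [matmul([row[r * k:(r + 1) * k] for row in x], w[r * k:(r + 1) * k])
                for r in range(tp)]
    return [[sum(p[i][j] for p in partials) for j in range(len(w[0]))]
            for i in range(len(x))]

x = [[1, 2, 3, 4], [5, 6, 7, 8]]  # 2 tokens, 4 input features
w = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1], [1, 1, 1, 0]]  # 4 x 4 weight
full = matmul(x, w)
print(full)  # [[11, 9, 8, 10], [27, 21, 24, 30]]
print(column_parallel(x, w, tp=2) == full, row_parallel(x, w, tp=2) == full)  # True True
```

Both sharded versions reproduce the unsharded result. They differ in what each rank holds before the merge: a slice of the output in the column case, and a full-width partial sum in the row case.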
Annotated Code Preview

Starter preview

Excerpt from code/lab-03-tensor-parallel-linear/tensor_parallel_linear.py. The linked starter file is the source of truth.

# Tensor Parallel Linear
# Annotated reading material. Running this file is optional.
# Source-of-truth focus: Read tensor parallelism as splitting one layer's matrix multiply across ranks.

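hidden, intermediate, tp = 4096, 11008, 2

# Column-parallel: each rank holds a slice of W's output columns.
column_shard = (hidden, intermediate // tp)  # (4096, 5504) per rank
col_partial = f"x @ W_col_shard{column_shard}"  # an output-feature slice, not a partial sum
# Concatenate the slices (all-gather) only if the next op needs the full width.
col_output = f"all_gather_concat({col_partial})"

# Row-parallel: each rank holds a slice of W's input rows and the matching input slice.
row_shard = (intermediate // tp, hidden)  # (5504, 4096) per rank
row_partial = f"x_shard @ W_row_shard{row_shard}"  # a full-width partial sum
# Row-parallel outputs are partial sums, so they need a reduce/all-reduce.
merged_output = f"all_reduce_sum({row_partial})"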

# What to explain while reading:
# - column_shard splits output features.
# - row_shard splits input features.
# - collective merges partial results when needed.
#
# Common traps:
# - TP is not DDP.
# - A sharded linear layer needs communication at layer boundaries.
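
A few lines are enough to check the shape arithmetic in the excerpt. This sketch assumes bf16 weights (2 bytes per parameter) and no bias. `shard_shapes` is a hypothetical helper and is not part of the starter file.

```python
# Minimal sketch: per-rank weight shapes and memory for the excerpt's sizes.
# Assumptions: bf16 weights (2 bytes/param), no bias, intermediate divisible by tp.
BYTES_PER_PARAM = 2

def shard_shapes(hidden, intermediate, tp):
    """Return the (column_shard, row_shard) weight shapes held by one rank."""
    if intermediate % tp != 0:
        raise ValueError("intermediate must be divisible by tp")
    return (hidden, intermediate // tp), (intermediate // tp, hidden)

hidden, intermediate, tp = 4096, 11008, 2
column_shard, row_shard = shard_shapes(hidden, intermediate, tp)
full_params = hidden * intermediate
shard_params = column_shard[0] * column_shard[1]

print(column_shard, row_shard)  # (4096, 5504) (5504, 4096)
print(full_params, shard_params)  # 45088768 22544384
print(shard_params * BYTES_PER_PARAM / 2**20, "MiB per rank per matrix")  # 43.0 ...
```

Each rank stores 1/tp of every sharded weight matrix. This saving is what TP buys, and the collectives are the price.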
Line-by-line Explanation

What each block is doing

Setup / contract
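`column_shard` splits output features. Each rank stores a `(hidden, intermediate // tp)` slice of the weight and computes the matching slice of output columns.
Main transition
`row_shard` splits input features. Each rank multiplies its slice of the input by its `(intermediate // tp, hidden)` weight rows and produces a full-width partial sum.
Interview hook
`collective` merges partial results when needed. An all-gather concatenates column-parallel slices, and an all-reduce sums row-parallel partial sums.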
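The blocks above combine in a transformer MLP. The first linear layer is column-parallel and the second is row-parallel, which is the layout described for Megatron-LM. Because the activation is element-wise, each rank can apply it to its own slice, so the forward pass needs only one all-reduce at the end. This minimal sketch simulates the ranks with plain Python lists and uses ReLU as the activation. All helper names are hypothetical.

```python
# Minimal sketch of a column-parallel -> row-parallel MLP, simulated in one process.
# ReLU stands in for the activation; helper names are hypothetical.

def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x n) matrix, both stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(m):
    return [[max(0, v) for v in row] for row in m]

def mlp_reference(x, w1, w2):
    """Unsharded MLP: relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)

def mlp_tensor_parallel(x, w1, w2, tp):
    n = len(w1[0]) // tp  # intermediate features per rank
    partials = []
    for r in range(tp):
        w1_cols = [row[r * n:(r + 1) * n] for row in w1]  # column_shard of W1
        w2_rows = w2[r * n:(r + 1) * n]  # row_shard of W2
        h_local = relu(matmul(x, w1_cols))  # element-wise: no communication needed
        partials.append(matmul(h_local, w2_rows))  # full-width partial sum
    # The single all-reduce: sum the partial outputs across ranks.
    return [[sum(p[i][j] for p in partials) for j in range(len(w2[0]))]
            for i in range(len(x))]

x = [[1, -2], [3, 1]]
w1 = [[1, -1, 2, 0], [0, 1, -1, 1]]
w2 = [[1, 0], [2, 1], [0, 1], [1, -1]]
print(mlp_reference(x, w1, w2))  # [[1, 4], [4, 4]]
print(mlp_tensor_parallel(x, w1, w2, tp=2) == mlp_reference(x, w1, w2))  # True
```

If the first layer were row-parallel instead, its partial sums would have to be reduced before the nonlinearity. That adds a collective, which is why the column-then-row order is the usual choice.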
What to Notice

Reading checkpoints

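  • TP splits each layer's weights and compute across ranks along the model width (hidden or intermediate features).
  • Because TP communicates inside every layer, a common heuristic is to keep each TP group on the fastest available interconnect, such as the links within a single node.
  • It does not increase the data-parallel batch by itself, because every rank in a TP group processes the same tokens.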
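A rough estimate shows why the interconnect matters. This back-of-envelope sketch makes three assumptions. First, it assumes a ring all-reduce, in which each rank sends about 2(p-1)/p times the tensor size. Second, it assumes bf16 activations. Third, it uses a hypothetical micro-batch of 4096 tokens.

```python
# Back-of-envelope sketch: bytes each rank sends for one all-reduce of a layer output.
# Assumptions: ring all-reduce (~2*(p-1)/p * N bytes sent per rank),
# bf16 activations (2 bytes), and a hypothetical micro-batch of 4096 tokens.

def ring_allreduce_bytes_per_rank(num_bytes, tp):
    """Approximate bytes sent per rank by a ring all-reduce over tp ranks."""
    return 2 * (tp - 1) / tp * num_bytes

tokens, hidden, bytes_per_elem = 4096, 4096, 2
activation_bytes = tokens * hidden * bytes_per_elem  # 32 MiB output tensor

for tp in (2, 4, 8):
    sent = ring_allreduce_bytes_per_rank(activation_bytes, tp)
    print(f"tp={tp}: ~{sent / 2**20:.0f} MiB sent per rank per all-reduce")
# tp=2: ~32 MiB, tp=4: ~48 MiB, tp=8: ~56 MiB
```

In a Megatron-style layout, this cost recurs about twice per transformer layer in the forward pass (after attention and after the MLP), and again in the backward pass. Slow links therefore show up directly in step time.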
Common Misunderstandings

What this lab prevents

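  • TP is not DDP. DDP replicates the whole model and splits the batch, while TP splits each layer's weights and gives the same batch to every rank in the group.
  • A sharded linear layer needs communication at layer boundaries, either to gather output slices or to sum partial results.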
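The TP-versus-DDP distinction becomes concrete when you list what a single rank holds. This minimal sketch compares shapes only. The function name and the batch size of 8 are hypothetical, and the TP case assumes a column-parallel layer.

```python
# Minimal sketch contrasting what one rank holds under DDP vs column-parallel TP.
# Shapes only; per_rank_view and the batch size are hypothetical.

def per_rank_view(strategy, batch, hidden, intermediate, world):
    if strategy == "ddp":  # full weight replica, a slice of the batch
        return {"input": (batch // world, hidden),
                "weight": (hidden, intermediate),
                "output": (batch // world, intermediate)}
    if strategy == "tp":  # a weight slice, the same full batch on every rank
        return {"input": (batch, hidden),
                "weight": (hidden, intermediate // world),
                "output": (batch, intermediate // world)}
    raise ValueError(f"unknown strategy: {strategy}")

for s in ("ddp", "tp"):
    print(s, per_rank_view(s, batch=8, hidden=4096, intermediate=11008, world=2))
```

The communication pattern differs as well. DDP all-reduces gradients once per step after the backward pass. TP exchanges activations inside the forward and backward passes of every sharded layer.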
Interview Explanation

How to say it out loud

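Start by stating that tensor parallelism splits one layer's matrix multiply across ranks. Then explain the code by naming three things. The first is the state being transformed (the weight shard). The second is the axis or shape that matters (output columns versus input rows). The third is the tradeoff that would appear in a real system: less memory and compute per rank in exchange for collectives inside every layer.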

External intuition notes

Additional intuition

  • Use official docs and papers for API behavior and factual claims; use blogs only to improve the mental picture.
  • If support matrices, performance behavior or backend choices are version-sensitive, check current docs before repeating them.
  • A strong interview answer names the state object, the shape or axis it changes, and the tradeoff it creates.
Further Reading

Official, paper and practical references