Lab 07: FlashAttention Mental Model
Annotated code reading lab. Running code is optional.
A Transformer block turns token ids into vectors, mixes context with attention, applies per-token nonlinear transformations, and uses residual and normalization layers to keep deep training stable.
Read code to understand the concept
Core mechanism
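- Naive attention materializes the full S x S score matrix and the S x S probability matrix in HBM (per head and batch element), so its memory traffic grows quadratically with sequence length S.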
- Explain the problem, the mechanism, the resource tradeoff, the common failure mode, and the measurement that would validate the claim.
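A quick way to see why the first bullet matters is to size one S x S intermediate. The sketch below is back-of-the-envelope arithmetic only; the sequence lengths, the 2-byte element size (fp16/bf16), and the 32-head, batch-8 configuration are illustrative assumptions, and score_matrix_bytes is a hypothetical helper.

```python
def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2,
                       heads: int = 1, batch: int = 1) -> int:
    """Bytes for one S x S attention matrix, per head and batch element."""
    return seq_len * seq_len * bytes_per_elem * heads * batch


# Footprint of a single score matrix for one head (assumed fp16, 2 bytes/element).
for s in (1024, 4096, 8192, 32768):
    mib = score_matrix_bytes(s) / 2**20
    print(f"S={s:>6}: {mib:>8.0f} MiB per head")

# Hypothetical layer: 32 heads, batch 8, S = 8192.
total_gib = score_matrix_bytes(8192, heads=32, batch=8) / 2**30
print(f"32 heads x batch 8 at S=8192: {total_gib:.0f} GiB for the scores alone")
```

For these sizes it reports 2, 32, 128, and 2048 MiB per head, and 32 GiB for the hypothetical 32-head, batch-8 case. Doubling S quadruples the footprint, and the probability matrix adds another buffer of the same size unless it overwrites the scores.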
Naive attention:
QK^T -> full S x S score matrix -> softmax -> V
FlashAttention intuition:
    for each Q block:
        for each streamed K/V block:
            update running max and denominator
            accumulate output
    avoid writing full S x S intermediates to HBM
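The pseudocode above tiles over Q blocks as well as K/V blocks, while the starter excerpt handles a single query row. As a hedged sketch of the two-level loop (pure-Python lists instead of tensors, float64 math, and hypothetical function names naive_attention and tiled_attention), it can be written and checked against a reference like this:

```python
import math
import random

Matrix = list[list[float]]


def naive_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference path: builds each full score row, applies softmax, then multiplies by V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        mx = max(scores)
        w = [math.exp(x - mx) for x in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def tiled_attention(q: Matrix, k: Matrix, v: Matrix, bq: int, bk: int) -> Matrix:
    """Blocked path: never holds more than a bq x bk block of scores at once."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out: Matrix = []
    for q0 in range(0, len(q), bq):                  # outer loop: one Q block
        q_blk = q[q0:q0 + bq]
        m = [-math.inf] * len(q_blk)                 # running max per query row
        l = [0.0] * len(q_blk)                       # running denominator per row
        acc = [[0.0] * dv for _ in q_blk]            # unnormalized output per row
        for k0 in range(0, len(k), bk):              # inner loop: stream K/V blocks
            k_blk, v_blk = k[k0:k0 + bk], v[k0:k0 + bk]
            for r, qi in enumerate(q_blk):
                tile = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
                m_new = max(m[r], max(tile))
                alpha = math.exp(m[r] - m_new)       # 0.0 on the first block (m = -inf)
                p = [math.exp(x - m_new) for x in tile]
                l[r] = l[r] * alpha + sum(p)
                acc[r] = [a * alpha + sum(pj * vj[c] for pj, vj in zip(p, v_blk))
                          for c, a in enumerate(acc[r])]
                m[r] = m_new
        out.extend([a / l[r] for a in acc[r]] for r in range(len(q_blk)))
    return out


random.seed(0)
S, D = 7, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(D)] for _ in range(S)] for _ in range(3))
ref = naive_attention(Q, K, V)
for bq, bk in [(1, 1), (2, 3), (4, 2), (7, 7)]:
    got = tiled_attention(Q, K, V, bq, bk)
    err = max(abs(x - y) for g_row, r_row in zip(got, ref) for x, y in zip(g_row, r_row))
    print(f"bq={bq} bk={bk} max abs diff vs naive: {err:.1e}")
```

This sketch keeps an unnormalized accumulator and divides by the running denominator once per row at the end, which is algebraically equivalent to the starter's per-block renormalization of out. Every block size should agree with the naive path up to float64 rounding.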
Annotated starter links
These files are reading material first. If you later decide to run them, treat the run as optional validation rather than the main learning path.
Starter Preview
Excerpt from code/lab-07-flashattention/flashattention_mental_model.py. This preview explains the key idea; the linked starter file is the source of truth.
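import torch
s, d, block_size = 8, 4, 3  # toy sizes (assumed; the starter file defines its own inputs)
q, k, v = torch.randn(s, d), torch.randn(s, d), torch.randn(s, d)
scale = d ** -0.5
# Naive path: materializes the full S x S score and probability matrices.
scores = q @ k.T * scale
probs = torch.softmax(scores, dim=-1)
naive_out = probs @ v
# Blocked path for one query row: stream K/V blocks with an online softmax.
q_row = q[:1]                                   # shape (1, d)
m = torch.tensor(float("-inf"))                 # running max
l = torch.tensor(0.0)                           # running denominator
out = torch.zeros(d)                            # running normalized output
for start in range(0, s, block_size):
    end = min(start + block_size, s)
    scores_block = (q_row @ k[start:end].T).squeeze(0) * scale
    m_new = torch.maximum(m, scores_block.max())
    alpha = torch.exp(m - m_new)                # rescale factor; 0 on the first block
    p = torch.exp(scores_block - m_new)
    l_new = l * alpha + p.sum()
    out = (out * l * alpha + p @ v[start:end]) / l_new
    m = m_new
    l = l_new
assert torch.allclose(out, naive_out[0], atol=1e-5)  # matches the naive row
Key code blocks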
- scores = q @ k.T: the naive path creates the full S x S matrix.
- m_new: tracks the running max needed for a numerically stable softmax as blocks arrive.
- alpha: rescales the old accumulator and denominator when the running max changes.
- l_new: the running softmax denominator, updated with each block.
- out update: combines the old output accumulator with the current V block's contribution.
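The m_new, alpha, and l_new bullets describe an online softmax normalizer. The stdlib sketch below (hypothetical helper names, scalar scores instead of tensors) streams scores one block at a time and checks two things: the running statistics reproduce the two-pass stable normalizer for any block size, and shifting by the running max avoids the overflow that an unshifted exp sum hits on large logits.

```python
import math


def online_softmax_stats(scores: list[float], block_size: int) -> tuple[float, float]:
    """Stream scores in blocks; return (m, l) = (running max, sum of exp(x - m))."""
    m, l = -math.inf, 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        alpha = math.exp(m - m_new)                  # rescales the old sum to the new max
        l = l * alpha + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l


def two_pass_stats(scores: list[float]) -> tuple[float, float]:
    """Stable reference: find the global max first, then sum the shifted exponentials."""
    m = max(scores)
    return m, sum(math.exp(x - m) for x in scores)


scores = [3.0, -1.0, 800.0, 2.5, 799.0, 0.0]         # includes logits too large for exp()
print(online_softmax_stats(scores, block_size=2))
print(two_pass_stats(scores))

try:
    sum(math.exp(x) for x in scores)                  # unshifted denominator
except OverflowError as err:
    print("unshifted exp overflowed:", err)
```

When the max jumps from 3.0 to 800.0, alpha underflows to 0.0, which correctly discards the earlier contributions that are negligible at the new scale instead of producing inf or NaN.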
How to read this code
- The algorithm computes the same attention result up to numerical tolerance.
- The speedup comes from fewer HBM reads and writes, not from fewer attention FLOPs.
- Masking and backward pass add implementation complexity in real kernels.
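The memory-IO bullet can be checked by counting instead of timing. The sketch below uses a deliberately simplified traffic model, assumed here rather than taken from the starter file: naive attention reads Q, K, V, writes and re-reads the S x S scores and probabilities, and writes O; the blocked version reads Q and writes O once and re-reads all of K and V once per Q block. The head dimension 64 and block size 128 are illustrative, and naive_cost and blocked_cost are hypothetical helpers.

```python
import math


def naive_cost(s: int, d: int) -> dict[str, int]:
    """Simplified element counts for naive attention on one head."""
    return {
        "score_dots": s * s,                 # one dot product per (query, key) pair
        "peak_scores": s * s,                # full S x S matrix lives in HBM
        # read Q, K; write S; read S, write P; read P and V; write O
        "hbm_elems": 4 * s * d + 4 * s * s,
    }


def blocked_cost(s: int, d: int, bq: int, bk: int) -> dict[str, int]:
    """Same counts when scores only exist as a bq x bk on-chip tile."""
    dots = peak = 0
    for q0 in range(0, s, bq):
        rows = min(bq, s - q0)
        for k0 in range(0, s, bk):
            cols = min(bk, s - k0)
            dots += rows * cols              # dot products computed for this tile
            peak = max(peak, rows * cols)    # tile is discarded after the update
    n_q_blocks = math.ceil(s / bq)
    # read Q and write O once; re-read all of K and V once per Q block
    hbm = 2 * s * d + n_q_blocks * 2 * s * d
    return {"score_dots": dots, "peak_scores": peak, "hbm_elems": hbm}


for s in (1024, 4096):
    print(s, naive_cost(s, d=64), blocked_cost(s, d=64, bq=128, bk=128))
```

The score dot-product count is identical in both paths; only the largest score buffer and the estimated HBM traffic change. The model ignores the elementwise softmax work and the backward pass, where FlashAttention recomputes scores instead of storing them.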
What this code does not mean
- “FlashAttention approximates attention.” It is exact within floating-point behavior.
- “The point is only speed.” The core mechanism is avoiding large HBM intermediates.
How to say it out loud
Naive attention writes the S x S scores and probabilities to HBM. FlashAttention processes K/V in blocks and uses an online softmax to keep only running statistics and output accumulators. It preserves the math while reducing HBM reads and writes.
Additional intuition
- The FlashAttention paper is the fact base: it is exact attention with IO-aware tiling, not an approximation method. Paper: FlashAttention
- Online-softmax explainers are useful because they make the running max and denominator update memorable before reading the full kernel implementation. Blog: Hugging Face community online softmax explainer
- IO-aware blog explanations are helpful for interviews: the key phrase is changing data movement, not changing the attention formula. Blog: FlashAttention IO-aware explanation
