InfraLens

A clear starting point for learning AI infrastructure.

Overview

Lab 07: FlashAttention Mental Model

Annotated code reading lab. Running code is optional.

Concept Goal

Read code to understand the concept

A Transformer block turns token ids into vectors, mixes context with attention, applies per-token nonlinear transformations, and uses residual and normalization layers to keep deep training stable.

Mental Model

Core mechanism

  • Naive attention materializes the full S x S score and probability matrices.
  • Explain the problem, the mechanism, the resource tradeoff, the common failure mode, and the measurement that would validate the claim.
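To put a number on the first bullet, the sketch below sizes one S x S matrix against the Q/K/V/O tensors. It assumes fp16 (2 bytes per element), batch 1, one head, and head dim 128, all chosen only for illustration. The helper names `attention_matrix_bytes` and `qkvo_bytes` are hypothetical.

```python
def attention_matrix_bytes(seq_len: int, batch: int = 1, heads: int = 1,
                           bytes_per_elem: int = 2) -> int:
    """Bytes for ONE S x S matrix (scores or probabilities); grows as S^2."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def qkvo_bytes(seq_len: int, head_dim: int, batch: int = 1, heads: int = 1,
               bytes_per_elem: int = 2) -> int:
    """Bytes for the Q, K, V inputs plus the O output; grows linearly in S."""
    return 4 * batch * heads * seq_len * head_dim * bytes_per_elem


MIB = 2 ** 20
for s in (1024, 4096, 16384):
    print(f"S={s:>6}: one S x S matrix = {attention_matrix_bytes(s) / MIB:6.1f} MiB, "
          f"Q/K/V/O (d=128) = {qkvo_bytes(s, 128) / MIB:5.1f} MiB")
```

Under these assumptions, doubling S quadruples each S x S intermediate but only doubles Q/K/V/O. At S = 16384 that is 512 MiB per S x S matrix versus 16 MiB for Q/K/V/O combined.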
Naive attention:
QK^T -> full S x S score matrix -> softmax -> V

FlashAttention intuition:
for each Q block:
  stream K/V blocks
  update running max and denominator
  accumulate output
avoid writing full S x S intermediates to HBM
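
The pseudocode above can be sketched in NumPy for every query block, not just one row. This is a CPU illustration under stated assumptions, not the lab's starter file or a real kernel. The function names and block sizes are hypothetical. Unlike the starter excerpt, this variant keeps the output accumulator unnormalized and divides by the denominator once per Q block at the end. The math is the same as the starter's per-step renormalization.

```python
import numpy as np


def blocked_attention(q, k, v, block_q=4, block_kv=4):
    """Exact softmax attention computed tile by tile with an online softmax.

    Only (block_q x block_kv) score tiles are formed; no S x S array exists.
    """
    s, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((s, v.shape[1]))
    for qs in range(0, s, block_q):                        # for each Q block
        qe = min(qs + block_q, s)
        q_blk = q[qs:qe]
        m = np.full(qe - qs, -np.inf)                       # running row max
        l = np.zeros(qe - qs)                               # running row denominator
        acc = np.zeros((qe - qs, v.shape[1]))               # unnormalized output
        for ks in range(0, s, block_kv):                    # stream K/V blocks
            ke = min(ks + block_kv, s)
            scores = q_blk @ k[ks:ke].T * scale             # one small score tile
            m_new = np.maximum(m, scores.max(axis=1))
            alpha = np.exp(m - m_new)                       # rescale old statistics
            p = np.exp(scores - m_new[:, None])
            l = l * alpha + p.sum(axis=1)
            acc = acc * alpha[:, None] + p @ v[ks:ke]
            m = m_new
        out[qs:qe] = acc / l[:, None]                       # normalize once per Q block
    return out


def naive_attention(q, k, v):
    scores = q @ k.T / np.sqrt(q.shape[1])                  # full S x S matrix
    scores -= scores.max(axis=1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)
    return probs @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
print(np.allclose(blocked_attention(q, k, v), naive_attention(q, k, v)))
```

With S = 10 and blocks of 4, the last tile is ragged (2 rows or columns), so the edge handling is exercised too.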
Starter files

Annotated starter links

These files are reading material first. If you later decide to run them, treat the run as optional validation rather than the main learning path.

Annotated Code Preview

Starter Preview

Excerpt from code/lab-07-flashattention/flashattention_mental_model.py. This preview explains the key idea; the linked starter file is the source of truth.

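import torch

# Setup for this preview only (illustrative sizes, not taken from the starter file).
s, d, block_size = 16, 8, 4
q, k, v = torch.randn(s, d), torch.randn(s, d), torch.randn(s, d)
scale = d ** -0.5

# Naive path: materializes the full S x S score and probability matrices.
scores = q @ k.T * scale
probs = torch.softmax(scores, dim=-1)
naive_out = probs @ v

# Blocked path for one query row: stream K/V blocks with an online softmax.
q_row = q[0:1]                                     # shape (1, d)
m, l, out = torch.tensor(float("-inf")), torch.tensor(0.0), torch.zeros(d)
for start in range(0, s, block_size):
    end = min(start + block_size, s)
    scores_block = (q_row @ k[start:end].T).squeeze(0) * scale
    m_new = torch.maximum(m, scores_block.max())   # running max
    alpha = torch.exp(m - m_new)                   # rescale factor for old stats
    p = torch.exp(scores_block - m_new)
    l_new = l * alpha + p.sum()                    # running denominator
    out = (out * l * alpha + p @ v[start:end]) / l_new
    m = m_new
    l = l_new
assert torch.allclose(out, naive_out[0], atol=1e-5)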
Line-by-line Explanation

Key code blocks

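scores = q @ k.T * scale
The naive path materializes the full S x S score matrix, then an S x S probability matrix from softmax.
m_new
Running maximum of all scores seen so far. Subtracting it before exp keeps the softmax numerically stable.
alpha
alpha = exp(m - m_new) rescales the old denominator and output accumulator whenever the running max increases. It equals 1 when the max does not change.
l_new
Running softmax denominator: the sum of exp(score - m_new) over every key processed so far.
out update
Multiplying by l * alpha undoes the previous normalization and shifts it to the new max. The current block then adds p @ V, and dividing by l_new renormalizes.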
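The same update can be written as equations. The notation is introduced here and is not the starter file's: s_j is the scaled score for key j, B is the current key block, and primes mark the new values. m' corresponds to m_new, ℓ to l, α to alpha, and o to out. The last line is the invariant that makes the result exact. It holds after processing any set J of keys, so after the final block o equals softmax(s) V.

```latex
\begin{aligned}
m' &= \max\Bigl(m,\ \max_{j \in B} s_j\Bigr), \qquad \alpha = e^{\,m - m'}, \qquad p_j = e^{\,s_j - m'} \\
\ell' &= \alpha\,\ell + \sum_{j \in B} p_j \\
o' &= \frac{\alpha\,\ell\,o + \sum_{j \in B} p_j\, v_j}{\ell'} \\
\text{invariant:}\quad \ell &= \sum_{j \in J} e^{\,s_j - m}, \qquad o = \sum_{j \in J} \frac{e^{\,s_j - m}}{\ell}\, v_j
\end{aligned}
```

The invariant carries through the update because α ℓ o = Σ_{j∈J} e^{s_j − m'} v_j. Rescaling by α moves every old term onto the new max m' without changing its value.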
What to Notice

How to read this code

  • The algorithm computes the same attention result up to numerical tolerance.
  • The win comes from memory I/O (far fewer HBM reads and writes), not from fewer attention FLOPs.
  • Masking and the backward pass add implementation complexity in real kernels.
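To make the masking point concrete, the sketch below adds a causal mask to the single-row streaming loop. It is a forward-pass-only NumPy illustration and does not show the backward pass. `causal_blocked_row` and `naive_causal` are hypothetical names. Key blocks that start after the query position are skipped entirely. The block that straddles the diagonal is masked with -inf, so those entries contribute exp(-inf) = 0.

```python
import numpy as np


def causal_blocked_row(q_row, k, v, i, block_size=4):
    """Online-softmax attention for query position i under a causal mask (keys j <= i)."""
    d = q_row.shape[0]
    m, l, acc = -np.inf, 0.0, np.zeros(v.shape[1])
    for start in range(0, i + 1, block_size):        # blocks with start > i are never visited
        end = min(start + block_size, k.shape[0])
        scores = k[start:end] @ q_row / np.sqrt(d)
        cols = np.arange(start, end)
        scores = np.where(cols <= i, scores, -np.inf)  # mask future keys inside the block
        m_new = max(m, scores.max())                 # finite: key `start` <= i is unmasked
        alpha = np.exp(m - m_new)
        p = np.exp(scores - m_new)                   # masked entries become exactly 0
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ v[start:end]
        m = m_new
    return acc / l


def naive_causal(q, k, v):
    s, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((s, s), dtype=bool)), scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)
    return probs @ v


rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((9, 8)) for _ in range(3))
blocked = np.stack([causal_blocked_row(q[i], k, v, i) for i in range(9)])
print(np.allclose(blocked, naive_causal(q, k, v)))
```

This sketch relies on the diagonal key always being unmasked, so the running max becomes finite on the first block. If a row could be fully masked, m_new would stay -inf and exp(m - m_new) would be exp(-inf - (-inf)) = NaN, which real kernels must guard against.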
Common Misunderstandings

What this code does not mean

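  • “FlashAttention approximates attention.” It computes exact attention; results differ from the naive path only by floating-point rounding.
  • “The point is only speed.” Speed is a consequence. The core mechanism is avoiding large S x S intermediates in HBM, which also makes the extra memory linear rather than quadratic in S.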
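One way to see why the result is exact is to look at the two numbers the loop carries, the running max m and denominator l. Together they recover the exact softmax normalizer no matter how the scores are chunked. The stdlib-only sketch below streams a list of scores in chunks and checks the resulting log-sum-exp against the direct sum. The function name and score values are illustrative assumptions.

```python
import math


def online_logsumexp(chunks):
    """Stream chunks of scores, keeping only a running max m and denominator l."""
    m, l = -math.inf, 0.0
    for chunk in chunks:
        m_new = max(m, max(chunk))
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    return m + math.log(l)          # log(sum(exp(x))) over every score seen


scores = [2.0, -1.0, 0.5, 3.0, 1.5, -2.0, 0.0]
direct = math.log(sum(math.exp(x) for x in scores))
for size in (1, 2, 3, 7):
    chunks = [scores[i:i + size] for i in range(0, len(scores), size)]
    print(size, math.isclose(online_logsumexp(chunks), direct))
```

Every chunk size gives the same normalizer up to rounding. Only the order of floating-point operations changes, which is why blocked and naive outputs agree within tolerance rather than bit-for-bit.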
Interview Explanation

How to say it out loud

Naive attention writes S x S scores and probabilities. FlashAttention processes K/V in blocks and uses online softmax to keep only running statistics and output accumulators. It preserves the math while reducing HBM reads/writes.
