InfraLens

A clear starting point for learning AI infrastructure.

Overview

Lab 04: Multi-Head Attention Minimal Code

Annotated code reading lab. Running code is optional.

Related handbook section

Multi-Head Attention Minimal Code

This lab maps directly to the handbook section. Read the related handbook section first, then use the lab page and starter file to connect the concept to concrete variables, shapes, APIs, and interview-ready explanations.

Concept Goal

Multi-Head Attention Minimal Code

Track the split-attend-merge pattern that makes multi-head attention a set of parallel read channels.

Mental Model

Mechanism to keep in mind

  • `split_heads` exposes independent attention subspaces.
  • `attention_per_head` runs the same score/value logic per head.
  • `merge_heads` returns to model width for the output projection.
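
The three steps above can be traced at the shape level. This is a minimal sketch, not the starter file: tuples stand in for tensor shapes, and `attention_per_head` and `merge_heads` are hypothetical helpers written here for illustration, since the excerpt on this page only defines `split_heads`.

```python
# Shape-only sketch: tuples stand in for tensor shapes, no tensors are created.
# attention_per_head and merge_heads are hypothetical helpers for illustration.

def split_heads(x_shape, heads):
    """(batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
    batch, seq, hidden = x_shape
    if hidden % heads != 0:
        raise ValueError("hidden must be divisible by heads")
    return (batch, heads, seq, hidden // heads)

def attention_per_head(q_shape, k_shape, v_shape):
    """Return the score shape and the context shape for per-head attention."""
    batch, heads, q_len, _ = q_shape
    kv_len = k_shape[2]
    scores_shape = (batch, heads, q_len, kv_len)        # Q @ K^T, one map per head
    context_shape = (batch, heads, q_len, v_shape[3])   # softmax(scores) @ V
    return scores_shape, context_shape

def merge_heads(context_shape):
    """(batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)."""
    batch, heads, seq, head_dim = context_shape
    return (batch, seq, heads * head_dim)

q = split_heads((2, 8, 32), heads=4)
k = split_heads((2, 8, 32), heads=4)
v = split_heads((2, 8, 32), heads=4)
scores, context = attention_per_head(q, k, v)
merged = merge_heads(context)
print(q, scores, context, merged)
```

With seq = 8 and head_dim = 8 the score shape happens to match the query shape. Try a different sequence length to see that the last axis of the scores is the key length, not the head width.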

Annotated Code Preview

Starter preview

Excerpt from code/lab-04-multi-head-attention/multi_head_attention.py. The linked starter file is the source of truth.

# Multi-Head Attention Minimal Code
# Annotated reading material. Running this file is optional.
# Source-of-truth focus: Track the split-attend-merge pattern that makes multi-head attention a set of parallel read channels.

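def split_heads(x_shape, heads):
    # Shape-only helper: takes (batch, seq, hidden), returns (batch, heads, seq, head_dim).
    batch, seq, hidden = x_shape
    return (batch, heads, seq, hidden // heads)

# Q, K, and V share one layout: (batch=2, heads=4, seq=8, head_dim=8).
q = split_heads((2, 8, 32), heads=4)
k = split_heads((2, 8, 32), heads=4)
v = split_heads((2, 8, 32), heads=4)

# Per-head attention keeps the query layout: (batch, heads, seq, head_dim).
context = (q[0], q[1], q[2], q[3])

# Merging folds heads back into the hidden axis: (batch, seq, heads * head_dim) == (2, 8, 32).
merged = (context[0], context[2], context[1] * context[3])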

# What to explain while reading:
# - split_heads exposes independent attention subspaces.
# - attention_per_head runs the same score/value logic per head.
# - merge_heads returns to model width for the output projection.
#
# Common traps:
# - Heads are not separate full models.
# - In the usual setup each Q/K/V projection stays D wide in total: H heads of width D / H, not H full copies of D.

Line-by-line Explanation

What each block is doing

  • Setup / contract: `split_heads` exposes independent attention subspaces.
  • Main transition: `attention_per_head` runs the same score/value logic per head.
  • Interview hook: `merge_heads` returns to model width for the output projection.
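
The three roles above can also be followed with real numbers. The block below is a small pure-Python sketch under stated assumptions: there is no batch axis, no masking, and no learned projections (the toy input is used directly as Q, K, and V), and `attention_one_head` and `merge_heads` are illustrative names rather than the starter file's API.

```python
import math

def split_heads(x, heads):
    """x: [seq][hidden] -> [heads][seq][head_dim] by slicing the hidden axis."""
    hidden = len(x[0])
    assert hidden % heads == 0, "hidden must be divisible by heads"
    head_dim = hidden // heads
    return [[row[h * head_dim:(h + 1) * head_dim] for row in x]
            for h in range(heads)]

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def attention_one_head(q, k, v):
    """Scaled dot-product attention for one head; q, k, v are [seq][head_dim]."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        # One score per key position, then a weighted sum over value rows.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        out.append([sum(w * v_row[d] for w, v_row in zip(weights, v))
                    for d in range(len(v[0]))])
    return out

def merge_heads(per_head):
    """[heads][seq][head_dim] -> [seq][heads * head_dim] by concatenation."""
    seq = len(per_head[0])
    return [[value for head in per_head for value in head[t]] for t in range(seq)]

# Toy input: 3 positions, hidden width 4 (assumed values, chosen for readability).
x = [[1.0, 0.0, 0.5, 0.5],
     [0.0, 1.0, 0.5, -0.5],
     [1.0, 1.0, 0.0, 0.0]]
heads = 2
q = k = v = split_heads(x, heads)

# The same function runs once per head; only the slice of the hidden axis differs.
context = [attention_one_head(q[h], k[h], v[h]) for h in range(heads)]
merged = merge_heads(context)
print(len(merged), len(merged[0]))
```

Every head sees all three sequence positions but only its own slice of the hidden axis, and the merged result is back at the original width, ready for an output projection.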

What to Notice

Reading checkpoints

  • Heads share the same sequence positions.
  • The output projection mixes head results.
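  • More heads also means more KV cache slices. At a fixed model width D each slice is narrower (D / H), so in standard multi-head attention the total cache per token stays the same.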
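
The KV cache point can be checked with simple arithmetic. This sketch assumes standard multi-head attention where K and V each hold one slice per head, and `kv_cache_layout` is a hypothetical helper that counts elements per token per layer.

```python
def kv_cache_layout(d_model, heads):
    """Return (slices, slice_width, total_elements) cached per token per layer.

    Assumes standard multi-head attention: one K slice and one V slice per head,
    each of width d_model // heads.
    """
    if d_model % heads != 0:
        raise ValueError("d_model must be divisible by heads")
    head_dim = d_model // heads
    slices = 2 * heads              # one K and one V slice per head
    total = slices * head_dim       # always 2 * d_model
    return slices, head_dim, total

for heads in (4, 8, 16):
    print(heads, kv_cache_layout(32, heads))
```

The slice count grows with the number of heads while the total stays at 2 * D, which is why head count alone does not change cache size in this setup.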

Common Misunderstandings

What this lab prevents

  • Heads are not separate full models.
  • In the usual setup each Q, K, and V projection stays D wide in total: H heads of width D / H each, not H full copies of D.
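
A parameter count makes both points concrete. The sketch below assumes the usual layout of four D-by-D weight matrices (Q, K, V, and output) with biases ignored. Both function names are hypothetical and exist only to contrast the real layout with the misconception.

```python
def projection_params(d_model, heads):
    """Weights in one multi-head attention layer, biases ignored."""
    head_dim = d_model // heads
    per_head_slice = d_model * head_dim      # one head's share of W_q (same for W_k, W_v)
    qkv = 3 * heads * per_head_slice         # equals 3 * d_model * d_model
    out = d_model * d_model                  # output projection that mixes head results
    return qkv + out

def separate_full_models_params(d_model, heads):
    """The misconception: every head as its own full-width attention layer."""
    return heads * 4 * d_model * d_model

print(projection_params(32, 4), projection_params(32, 8))
print(separate_full_models_params(32, 4))
```

Changing the head count from 4 to 8 leaves the first number unchanged, because heads partition the width D rather than replicate it.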

Interview Explanation

How to say it out loud

Track the split-attend-merge pattern that makes multi-head attention a set of parallel read channels. Then explain the code by naming the state being transformed, the axis or shape that matters, and the tradeoff that would appear in a real system.

External intuition notes

Additional intuition

  • Use official docs and papers for API behavior and factual claims; use blogs only to improve the mental picture.
  • If support matrices, performance behavior or backend choices are version-sensitive, check current docs before repeating them.
  • A strong interview answer names the state object, the shape or axis it changes, and the tradeoff it creates.

Further Reading

Official, paper and practical references