AI Infra Deep Dive

Transformer Systems

A deep-dive branch of the AI Infra umbrella, focused on tokenization, QKV, attention backends, RoPE, FFN, KV Cache, quantization, and inference mechanics.

What you will learn

Mechanics behind model systems

Use this page to connect token flow, attention shapes, memory behavior, backend choices, and serving constraints without turning the Transformer into a black box.

Reading path

Read the related handbook section first, then use the lab page and starter file to connect the concept to concrete variables, shapes, APIs, and interview-ready explanations.

overview

Overview: Why the Transformer Matters

The Transformer architecture revolutionized sequence modeling by replacing sequential recurrence with parallelizable self-attention mechanisms, significantly accelerating training throughput on modern accelerator hardware.

What problem does this solve?

RNNs and LSTMs process sequences sequentially: each hidden state depends on the previous one, so a sequence of length S requires S dependent steps that cannot be parallelized across time. The Transformer removes this step dependency, letting every token attend to all other tokens (or, with a causal mask, all earlier tokens) in parallel during training, which makes far better use of GPU parallelism.

Core mechanism

Instead of passing a hidden state step-by-step, the Transformer projects input embeddings into Query, Key, and Value vectors. By calculating scaled dot-product similarity between all Queries and Keys, it computes contextual routing weights to mix the Values in parallel.

What to say in an interview

Focus on execution speed and scaling: the Transformer turned sequence modeling from a sequential recurrent dependency into a parallelizable matrix multiplication problem, enabling the training of models with billions or trillions of parameters on massive datasets.

Common misunderstanding

Thinking attention is the only important part. In practice, the Feed-Forward Network (FFN) layers hold the majority of a dense model's parameters and are often interpreted as key-value storage for factual knowledge, while residual connections and LayerNorm/RMSNorm are what make deep training stable.

tokenization

Tokenization and Embeddings

Tokenization converts raw characters into discrete subword integer IDs, which are then mapped to continuous dense vectors using an embedding lookup table.

What problem does this solve?

Neural networks cannot process raw characters or strings directly. Subword tokenization (using algorithms such as Byte-Pair Encoding or Unigram, often via libraries like SentencePiece) maps text to IDs from a fixed vocabulary, handling out-of-vocabulary words by splitting them into smaller subword pieces or byte sequences.
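
A minimal sketch of BPE training in Python, assuming a toy word-frequency corpus, character-level starting symbols, and no byte fallback or pre-tokenization rules (real tokenizers add all of these). The function names are hypothetical: each round counts adjacent symbol pairs, merges the most frequent one, and records it as a new vocabulary entry.

```python
from collections import Counter

Word = tuple[str, ...]

def pair_counts(words: dict[Word, int]) -> Counter:
    """Count adjacent symbol pairs, weighted by how often each word occurs."""
    counts: Counter = Counter()
    for symbols, freq in words.items():
        for pair in zip(symbols, symbols[1:]):
            counts[pair] += freq
    return counts

def merge_pair(words: dict[Word, int], pair: tuple[str, str]) -> dict[Word, int]:
    """Rewrite every word, fusing each occurrence of `pair` into one symbol."""
    merged: dict[Word, int] = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged

def train_bpe(corpus: dict[str, int], num_merges: int) -> tuple[list[tuple[str, str]], dict[Word, int]]:
    words = {tuple(word): freq for word, freq in corpus.items()}  # start from characters
    merges = []
    for _ in range(num_merges):
        counts = pair_counts(words)
        if not counts:
            break
        best = max(counts, key=counts.get)  # most frequent pair; ties go to the first seen
        merges.append(best)
        words = merge_pair(words, best)
    return merges, words

merges, words = train_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, num_merges=3)
print(merges)  # [('e', 's'), ('es', 't'), ('l', 'o')]
```

On this toy corpus the learned merges build the shared suffix "est" first, then "lo"; when encoding new text, the same merges are replayed in the order they were learned.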

Core mechanism

The embedding layer is a weight matrix of shape (V, D), where V is the vocabulary size and D is the hidden dimension. An integer token ID is used as a row index to look up a D-dimensional vector. These vectors are the starting representation before any layer transformations.
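
A toy illustration of the lookup, assuming a tiny random table in place of learned weights: the forward pass is just row indexing into the (V, D) matrix, which is mathematically the same as multiplying a one-hot vector by it but far cheaper.

```python
import random

V, D = 8, 4      # toy vocabulary size and hidden dimension (illustrative)
random.seed(0)
# Embedding table of shape (V, D): one row per token ID (random here, learned in practice).
embedding = [[random.gauss(0.0, 0.02) for _ in range(D)] for _ in range(V)]

def embed(token_ids: list[int]) -> list[list[float]]:
    """Map S token IDs to an (S, D) list of vectors by row indexing."""
    return [embedding[t] for t in token_ids]

def embed_via_one_hot(token_id: int) -> list[float]:
    """Equivalent but wasteful form: one_hot(token_id) @ embedding."""
    one_hot = [1.0 if i == token_id else 0.0 for i in range(V)]
    return [sum(one_hot[i] * embedding[i][j] for i in range(V)) for j in range(D)]

x = embed([3, 1, 3])
print(len(x), len(x[0]))  # 3 4: shape (S, D)
```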

What to say in an interview

Vocabulary size is a memory-compute tradeoff: a larger vocabulary V shortens the sequence length S for the same text (reducing attention cost), but increases the embedding parameters (V * D, doubled if the output projection is not tied to the input embedding), which can be a large share of the weights in smaller models.
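
Rough arithmetic for the embedding side of the tradeoff, assuming D = 4096, 2-byte weights, and separate (untied) input and output matrices; tying them halves the count. The vocabulary sizes are illustrative.

```python
def embedding_params(vocab_size: int, hidden: int, tied: bool = False) -> int:
    """Input embedding plus output (LM head) projection, each of shape (V, D)."""
    per_matrix = vocab_size * hidden
    return per_matrix if tied else 2 * per_matrix

for vocab in (32_000, 128_000):
    n = embedding_params(vocab, 4_096)
    print(f"V={vocab:>7,}: {n / 1e6:,.0f}M params, {n * 2 / 2**30:.2f} GiB at 2 bytes/param")
```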

Common misunderstanding

Assuming that a token always equals a word. Many technical words, code keywords, or foreign characters are split into multiple subwords or individual characters by the tokenizer, which changes the effective sequence length processed by the attention blocks.

block

Decoder-only Transformer Block

Decoder-only blocks process unidirectional context using a causal mask, allowing efficient autoregressive text generation and streamlined model training.

What problem does this solve?

Generative models must predict the next token using only prior tokens. While encoder-decoder architectures use separate networks and cross-attention, decoder-only models simplify memory management and scaling by combining all tokens into a single causally-masked sequence.

Core mechanism

A typical modern decoder-only layer uses a Pre-LN (pre-norm) residual stream: input -> RMSNorm -> Causal Attention -> Add -> RMSNorm -> FFN -> Add, i.e. x = x + CausalAttention(RMSNorm(x)) followed by x = x + FFN(RMSNorm(x)). Sub-layers receive normalized activations, while gradients flow directly through the un-normalized addition path.
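
A structural sketch of the pre-norm layout in plain Python, assuming the attention and FFN sub-layers are supplied as callables and the input is one sequence of token vectors. The zero-output stand-ins are hypothetical placeholders, not real layers; with them the block reduces exactly to the identity, which is the clean residual path gradients can follow.

```python
import math
from typing import Callable

Vec = list[float]

def rms_norm(x: Vec, gain: Vec, eps: float = 1e-6) -> Vec:
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v * inv_rms for g, v in zip(gain, x)]

def add(a: Vec, b: Vec) -> Vec:
    return [p + q for p, q in zip(a, b)]

def pre_norm_block(
    xs: list[Vec],
    attn: Callable[[list[Vec]], list[Vec]],  # mixes information across positions
    ffn: Callable[[Vec], Vec],               # applied to each position independently
    g_attn: Vec,
    g_ffn: Vec,
) -> list[Vec]:
    # x = x + CausalAttention(RMSNorm(x)): only the branch input is normalized.
    normed = [rms_norm(x, g_attn) for x in xs]
    hs = [add(x, a) for x, a in zip(xs, attn(normed))]
    # x = x + FFN(RMSNorm(x)): the residual stream itself is never normalized here.
    return [add(h, ffn(rms_norm(h, g_ffn))) for h in hs]

def zero_attn(seq: list[Vec]) -> list[Vec]:
    return [[0.0] * len(v) for v in seq]

def zero_ffn(v: Vec) -> Vec:
    return [0.0] * len(v)

xs = [[1.0, -2.0, 3.0], [0.5, 0.0, -1.0]]
ones = [1.0, 1.0, 1.0]
print(pre_norm_block(xs, zero_attn, zero_ffn, ones, ones) == xs)  # True
```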

What to say in an interview

Explain that Pre-LN (normalizing each sub-layer's input rather than the residual sum) matters for scaling. Post-LN normalizes after the residual addition and typically needs a careful learning-rate warm-up to avoid unstable gradients early in training; Pre-LN is generally stable from initialization and needs little or no warm-up.

Common misunderstanding

Assuming that decoder-only blocks are less capable of processing input prompts because they mask future tokens. The causal mask only prevents future tokens from affecting past positions, which aligns perfectly with how autoregressive generation is computed during serving.

qkv

Q/K/V and Attention Shapes

Self-attention projects representations into Query, Key, and Value spaces and computes similarity scores to route information contextually across the sequence.

What problem does this solve?

To compute context-aware token representations, each token needs to query other positions (only earlier ones in a decoder). Projection weights transform each token's vector into a Query (what it seeks), a Key (what it contains), and a Value (the content it contributes to matching queries).

Core mechanism

For input of shape (B, S, D), the projections (reshaped and transposed per head) yield Q, K, V of shape (B, H, S, Dh) where Dh = D/H. Computing similarity (Q @ K^T) yields scores of shape (B, H, S, S). Softmax-normalized scores are multiplied by V to give (B, H, S, Dh); the heads are then concatenated back to (B, S, D) and passed through the output projection W_O.

Formula (scaled dot-product attention): Attention(Q, K, V) = softmax(QK^T / sqrt(d_k) + mask) V
  • Q: query vectors for positions asking what to read.
  • K: key vectors for positions being compared against.
  • V: value vectors carrying the content to mix.
  • d_k: key/head dimension (Dh above) used to scale scores.
  • QK^T: similarity score matrix; a causal mask adds -inf at future positions before softmax.
tokens
  -> embedding + position: X (B, S, D)
  -> Q/K/V projections:   Q,K,V (B, H, S, Dh)
  -> scores QK^T:         scores (B, H, S, S)
  -> mask + softmax + V:  context (B, H, S, Dh)
  -> merge heads + W_O:   output (B, S, D)
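
A minimal single-head sketch of the formula in pure Python, assuming B = 1 and H = 1 so that Q, K and V are plain (S, Dh) lists; real implementations batch over (B, H) and dispatch to fused kernels. The causal mask is applied by setting future scores to -inf before softmax.

```python
import math

Matrix = list[list[float]]

def softmax(row: list[float]) -> list[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k) + mask) V for one head; (S, Dh) inputs -> (S, Dh) output."""
    s, d_k = len(q), len(q[0])
    out = []
    for i in range(s):
        # Score query i against every key; positions j > i are masked to -inf.
        scores = [
            sum(qi * kj for qi, kj in zip(q[i], k[j])) / math.sqrt(d_k) if j <= i else -math.inf
            for j in range(s)
        ]
        weights = softmax(scores)  # masked entries become exactly 0.0
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) for c in range(len(v[0]))])
    return out

q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out[0])  # [1.0, 2.0]: the first position can only attend to itself
```

Changing v[2] would leave the outputs at positions 0 and 1 untouched, which is the causal mask at work.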
What to say in an interview

Walk through the shapes at each step. Emphasize that the scaling factor 1/sqrt(Dh) is crucial: with roughly unit-variance components, the variance of a dot product grows linearly with the head dimension Dh, and without scaling the large scores push softmax toward a near one-hot distribution whose gradients are close to zero.
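
A short justification of the scaling, under the standard simplifying assumption that the components of a query q and key k are independent with mean 0 and variance 1 (here d_k = Dh):

```latex
q \cdot k = \sum_{i=1}^{d_k} q_i k_i, \qquad
\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k,
\qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1
```

So unscaled scores have a standard deviation of about sqrt(d_k), roughly 11 for d_k = 128, which is large enough to saturate softmax.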

Common misunderstanding

Thinking Q, K, and V come from separate source arrays. In self-attention they are three different linear projections of the same input representation X. Only in encoder-decoder cross-attention do the sources differ: Q is projected from the decoder, while K and V come from the encoder's output.

mha

Multi-Head Attention

Multi-Head Attention partitions the model representation into independent head subspaces, allowing the model to target multiple dependency relationships simultaneously.

What problem does this solve?

A single attention head produces one attention distribution per position, forcing the model to average all of its contextual lookups into that one mixture. Multi-head attention projects the input into H different subspaces, so heads can specialize in different relationships such as syntax, subject-verb agreement, or factual links.

Core mechanism

MHA splits the representation dimension D into H heads of size Dh = D/H. During inference, storing K/V vectors for every head creates a large cache memory overhead. Many modern architectures therefore use Multi-Query Attention (MQA) or Grouped-Query Attention (GQA), in which query heads share a single K/V head or one K/V head per group.

Variant | KV organization | Serving implication | Tradeoff
MHA | Each query head has its own K/V head. | Largest KV Cache and decode bandwidth demand. | Maximum head-specific capacity.
MQA | All query heads share one K/V head. | Smallest KV Cache among these variants. | Most aggressive sharing can reduce quality or compatibility.
GQA | Groups of query heads share K/V heads. | Substantial KV Cache reduction for serving. | Balances cache savings with representation capacity.
What to say in an interview

Explain the serving trade-offs: MQA uses 1 KV head for all Q heads, which minimizes cache memory but can hurt representation capacity. GQA lets groups of query heads share a smaller set of KV heads (e.g. 8 Q heads per KV head), offering a sweet spot between memory savings and model quality.
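
A sketch of the head-to-group mapping, assuming the common convention that consecutive query heads share a K/V head (query head h reads K/V head h // group_size) and illustrative counts of 64 query heads with 64, 8, or 1 K/V heads. Cache size scales with the number of K/V heads, so the ratio to MHA is n_kv / n_q.

```python
def kv_head_for(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """K/V head read by a given query head when consecutive query heads share one."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size

n_q = 64
for name, n_kv in (("MHA", 64), ("GQA", 8), ("MQA", 1)):
    mapping = [kv_head_for(h, n_q, n_kv) for h in range(n_q)]
    # Cache size is proportional to the number of K/V heads.
    print(f"{name}: {n_kv} KV heads, cache = {n_kv / n_q:.3f} x MHA, query head 9 -> KV head {mapping[9]}")
```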

Common misunderstanding

Thinking multi-head attention costs more compute than a hypothetical single head of dimension D. For a fixed D, splitting it into H heads keeps the projection and attention matmul FLOPs the same (H heads x Dh = D); what grows with H is the number of S x S score entries that softmax must process, plus minor overhead from reshaping heads.
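
A quick operation count supporting this, assuming D is divisible by H and counting only the two attention matmuls (QK^T and weights @ V); the W_Q, W_K, W_V, W_O projections are D x D in both cases, so they do not change either.

```python
def attention_core_cost(seq_len: int, d_model: int, n_heads: int) -> tuple[int, int]:
    """Multiply-accumulates in QK^T and weights @ V, plus the number of score entries."""
    head_dim = d_model // n_heads
    macs_per_head = 2 * seq_len * seq_len * head_dim  # QK^T, then (S x S) @ (S x Dh)
    return n_heads * macs_per_head, n_heads * seq_len * seq_len

one_head = attention_core_cost(seq_len=1024, d_model=4096, n_heads=1)
many_heads = attention_core_cost(seq_len=1024, d_model=4096, n_heads=32)
print(one_head[0] == many_heads[0], many_heads[1] // one_head[1])  # True 32
```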

rope

Positional Encoding / RoPE

Rotary Position Embeddings inject position information into attention by rotating Query and Key vectors in 2D slices, encoding relative distances geometrically.

What problem does this solve?

Without position information, self-attention is permutation-equivariant: shuffling the input tokens just shuffles the outputs, so the model effectively sees an unordered bag of tokens. Positional encodings restore sequence order, enabling the model to distinguish between different word orders.

Core mechanism

RoPE (Rotary Position Embedding) splits Query and Key vectors into 2D coordinate pairs. For a token at position m, it rotates pair i by the angle m * theta_i, where each pair has its own frequency theta_i = base^(-2i/Dh) (base is commonly 10000). Because rotations compose, the dot product between a rotated Query at position m and a rotated Key at position n depends only on their relative offset (m - n).
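
A pure-Python sketch of the rotation on a single vector, assuming the "adjacent pairs" layout (dimensions 2i and 2i+1 form a pair) and base 10000; some implementations pair dimension i with i + Dh/2 instead, which is a fixed permutation of the same idea. The check at the end confirms that the score depends only on the offset m - n.

```python
import math

def rope(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Rotate each (x[2i], x[2i+1]) pair by the angle pos * theta_i."""
    d = len(x)
    if d % 2 != 0:
        raise ValueError("head dimension must be even")
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)            # per-pair frequency
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]   # 2D rotation
    return out

def dot(u: list[float], w: list[float]) -> float:
    return sum(a * b for a, b in zip(u, w))

q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.6, 0.9]
# Same offset m - n = 3 at very different absolute positions.
s1 = dot(rope(q, 5), rope(k, 2))
s2 = dot(rope(q, 105), rope(k, 102))
print(abs(s1 - s2) < 1e-9)  # True
```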

What to say in an interview

Explain why RoPE is widely preferred over learned absolute positional embeddings: it encodes relative offsets directly in the attention scores. Context-extension techniques such as position interpolation and NTK-aware scaling build on this by rescaling positions or changing the RoPE base.

Common misunderstanding

Thinking RoPE is added to the word embeddings at the bottom of the network. Unlike absolute encodings, RoPE is applied dynamically to the Queries and Keys at every single self-attention layer, keeping positional information intact as representation flows through deep blocks.

ffn

MLP / FFN / SwiGLU

Feed-forward networks apply per-token non-linear transformations and are often interpreted as key-value associative memories that store and retrieve factual knowledge.

What problem does this solve?

Attention only routes information between tokens but does not execute complex non-linear computations on individual representations. The Feed-Forward Network (FFN/MLP) provides the computational capacity to refine token representations and store factual knowledge.

Core mechanism

Many modern models replace the standard ReLU/GELU MLP with a SwiGLU block, which uses three projection matrices (gate_proj, up_proj, and down_proj): FFN_SwiGLU(X) = (Swish(X @ W_gate) * (X @ W_up)) @ W_down, where * is element-wise multiplication and Swish with beta = 1 is SiLU. Gated variants like this have been reported to improve quality at comparable compute.
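
A per-token sketch of the SwiGLU formula, assuming Swish with beta = 1 (SiLU), no biases, and tiny hand-picked weight matrices stored as lists of rows; the weights and helper names are hypothetical.

```python
import math

def silu(z: float) -> float:
    """Swish with beta = 1: z * sigmoid(z), written to avoid exp overflow."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)

def matvec(x: list[float], w: list[list[float]]) -> list[float]:
    """x @ W, where W is stored as len(x) rows of length out_dim."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def swiglu_ffn(x, w_gate, w_up, w_down):
    gate = [silu(g) for g in matvec(x, w_gate)]   # (D_ff,)
    up = matvec(x, w_up)                          # (D_ff,)
    hidden = [g * u for g, u in zip(gate, up)]    # element-wise gating
    return matvec(hidden, w_down)                 # back to (D,)

x = [1.0, -1.0]                                   # D = 2
w_gate = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]       # (D, D_ff) = (2, 3)
w_up = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]         # (D, D_ff)
w_down = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]     # (D_ff, D) = (3, 2)
out = swiglu_ffn(x, w_gate, w_up, w_down)
print(out)  # a negative gate input is damped rather than zeroed
```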

What to say in an interview

Explain the dimension math: at the same hidden width D_ff, SwiGLU has 1.5x the parameters of a standard FFN (three D x D_ff matrices instead of two). To keep parameters and FLOPs equivalent to a standard FFN with D_ff = 4D, the SwiGLU hidden width is typically scaled down to about 8/3 D.
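
The parameter arithmetic as a small calculator, ignoring biases and assuming D = 4096 for illustration. In practice D_ff is often rounded to a hardware-friendly multiple (LLaMA-7B, for instance, uses 11008 with D = 4096), so the match is approximate.

```python
def ffn_params(d_model: int, d_ff: int, gated: bool) -> int:
    """Weights in one FFN block, ignoring biases."""
    n_matrices = 3 if gated else 2  # SwiGLU: gate, up, down; classic MLP: up, down
    return n_matrices * d_model * d_ff

D = 4096
classic = ffn_params(D, 4 * D, gated=False)              # 8 * D^2
swiglu_same_width = ffn_params(D, 4 * D, gated=True)     # 12 * D^2, i.e. 1.5x
swiglu_matched = ffn_params(D, 8 * D // 3, gated=True)   # 3 * D * (8D/3), about 8 * D^2
print(swiglu_same_width / classic, swiglu_matched / classic)
```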

Common misunderstanding

Believing the attention layers store most of the model's factual knowledge. In a standard dense layer (4D^2 attention weights vs 8D^2 FFN weights), the FFN holds roughly 2/3 of the non-embedding parameters. FFNs are often described as key-value memory banks, with attention acting as the router that selects which memories to read.

norm

Residual and LayerNorm

Residual connections combined with normalization layers stabilize activation distributions, preventing vanishing or exploding gradients in deep networks.

What problem does this solve?

In networks with dozens of layers, gradients vanish or explode during backpropagation, stopping learning. Residual additions allow gradients to bypass layer transformations. Normalization prevents hidden state magnitudes from drifting across layers.

Core mechanism

LayerNorm normalizes features across the hidden dimension for each token: it subtracts the mean, divides by the standard deviation, then applies a learned gain and bias. Many modern decoder models use RMSNorm instead, which divides by the root mean square of the features without subtracting the mean and applies only a learned gain, saving some computation.
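
A side-by-side sketch of the two per-token normalizations in pure Python (the eps values are typical choices, assumed here). On zero-mean input they coincide, which isolates the only real difference: the centering step.

```python
import math

def layer_norm(x: list[float], gain: list[float], bias: list[float], eps: float = 1e-5) -> list[float]:
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)   # second statistic: variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(x, gain, bias)]

def rms_norm(x: list[float], gain: list[float], eps: float = 1e-6) -> list[float]:
    # One statistic only: no mean subtraction and no bias.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v * inv_rms for v, g in zip(x, gain)]

ones, zeros = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
x = [2.0, 4.0, 6.0]
print(layer_norm(x, ones, zeros))   # centered: values sum to ~0
print(rms_norm(x, ones))            # not centered: mean square ~1
x0 = [-1.0, 0.0, 1.0]               # already zero-mean
print(layer_norm(x0, ones, zeros, eps=1e-6), rms_norm(x0, ones, eps=1e-6))
```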

Normalization | Statistic and axis | Decoder-model role | Boundary
LayerNorm | Mean and variance across features for each token. | Stabilizes each token representation independently. | Includes centering and scaling work.
RMSNorm | Root mean square across features for each token. | Common lighter normalization in modern decoder models. | Does not subtract the feature mean.
BatchNorm | Batch-dependent statistics. | Not the standard normalization for autoregressive decoders. | Couples examples in a batch and is not interchangeable here.
What to say in an interview

Highlight the computational benefit of RMSNorm: by dropping mean subtraction and the bias, it needs one reduction statistic per token instead of two, which trims kernel work and memory traffic, and in practice it is reported to train about as well as standard LayerNorm.

Common misunderstanding

Confusing LayerNorm with BatchNorm. BatchNorm normalizes features across the batch dimension, making token representations dependent on other batch items. LayerNorm normalizes independently per token, preserving causal isolation during text generation.

training

Training Loop and Loss

Autoregressive training utilizes teacher forcing and cross-entropy loss to optimize next-token predictions in parallel across all sequence positions.

What problem does this solve?

Autoregressive models must learn to predict every next token, and doing so one position at a time would be slow. Teacher forcing feeds the ground-truth previous tokens as input (instead of the model's own predictions), and a causal attention mask ensures position i cannot see later positions, so the loss can be computed for all positions in parallel.

Core mechanism

Input tokens pass through the network to yield logits of shape (B, S, V). The logits at position i are compared against the token at position i + 1 (targets shifted by one) using cross-entropy loss. Gradients are computed, the optimizer updates the weights, and the loop repeats; the loss typically decreases and then plateaus over iterations.
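
A minimal sketch of the shifted next-token loss for one sequence, assuming B = 1 and plain Python lists for logits; frameworks do the same thing as one batched call over logits[:, :-1] and tokens[:, 1:]. The log-sum-exp form avoids overflow.

```python
import math

def cross_entropy_next_token(logits: list[list[float]], tokens: list[int]) -> float:
    """Mean next-token loss for one sequence.

    logits: shape (S, V); logits[i] is the prediction made at position i.
    tokens: shape (S,); the target for position i is tokens[i + 1].
    """
    losses = []
    for i in range(len(tokens) - 1):              # the last position has no next token
        row, target = logits[i], tokens[i + 1]
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # log-sum-exp
        losses.append(log_z - row[target])        # -log softmax(row)[target]
    return sum(losses) / len(losses)

tokens = [0, 2, 1]
uniform = [[0.0] * 4 for _ in tokens]             # (S, V) with V = 4
print(cross_entropy_next_token(uniform, tokens))  # log(4), about 1.386
```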

What to say in an interview

Explain that teacher forcing speeds up training but creates "exposure bias" because the model never learns to correct its own errors during training. During inference, it relies on its own generated tokens, where small errors can compound over steps.

Common misunderstanding

Thinking training runs sequentially. Because the entire ground-truth sequence is available during training, we compute predictions and loss for all token positions concurrently in a single forward pass, which is highly efficient.

kv-cache

KV Cache and Autoregressive Inference

KV Cache trades GPU memory for computation by storing past Key and Value states, reducing the attention cost of generating each new token from O(S^2) (recomputing the whole prefix) to O(S) (one new query attending to S cached keys).

What problem does this solve?

In incremental token generation, each step only needs the prediction for the next token. Without caching, every step would re-run the forward pass over the whole prefix, recomputing Queries, Keys, and Values for all past tokens and O(S^2) attention scores even though only the last row is new.

Core mechanism

Inference consists of Prefill (computing K/V for all prompt tokens in a compute-bound forward pass) and Decode (incremental generation, where we only compute Q/K/V for the single new token, append K/V to the cache, and attend to the full cache).

Phase | Input and state change | Dominant metric | Typical pressure
Prefill | Processes the prompt in parallel and creates the initial KV Cache. | TTFT (time to first token) | Full-sequence attention and prompt compute.
Decode | Processes one generated token and appends its K/V to the cache. | TPOT (time per output token) | Repeated cache reads, bandwidth, and scheduling.
Formula (KV Cache memory): total_bytes = 2 * layers * batch * seq_len * kv_heads * head_dim * bytes_per_elem
  • 2: key and value tensors.
  • layers: each decoder layer stores its own cache.
  • batch and seq_len: concurrency and cached context length.
  • kv_heads and head_dim: K/V representation size; MQA/GQA reduce kv_heads.
  • bytes_per_elem: storage precision per element (e.g. 2 for FP16/BF16).
request prompt
  -> prefill computes K/V for all prompt tokens
  -> decode step computes Q for the new token
  -> Q attends to cached K/V plus current K/V
  -> append current K/V and stream next token
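
The flow above as a runnable single-layer, single-head toy, assuming B = 1 and hypothetical 2-dimensional projections (proj_q, proj_k, proj_v) in place of real weights. It checks that cached decoding produces exactly the same output as recomputing K/V for the whole sequence.

```python
import math

Vec = list[float]

def attend(q: Vec, keys: list[Vec], values: list[Vec]) -> Vec:
    """softmax(q . k / sqrt(d)) over the given keys, then a weighted sum of values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

# Hypothetical toy projections standing in for a real layer's W_Q, W_K, W_V.
def proj_q(x: Vec) -> Vec:
    return [x[0] + x[1], x[0] - x[1]]

def proj_k(x: Vec) -> Vec:
    return [2.0 * x[0], x[1]]

def proj_v(x: Vec) -> Vec:
    return [x[1], x[0]]

class KVCache:
    """Per-layer, per-sequence store of past keys and values (single head)."""

    def __init__(self) -> None:
        self.keys: list[Vec] = []
        self.values: list[Vec] = []

    def append(self, k: Vec, v: Vec) -> None:
        self.keys.append(k)
        self.values.append(v)

def prefill(cache: KVCache, prompt: list[Vec]) -> Vec:
    """Fill the cache with K/V for every prompt token (one batched pass in real systems)."""
    for x in prompt:
        cache.append(proj_k(x), proj_v(x))
    return attend(proj_q(prompt[-1]), cache.keys, cache.values)

def decode_step(cache: KVCache, x_new: Vec) -> Vec:
    """Compute K/V only for the new token; Q attends to cached K/V plus the current K/V."""
    cache.append(proj_k(x_new), proj_v(x_new))
    return attend(proj_q(x_new), cache.keys, cache.values)

def no_cache_last(seq: list[Vec]) -> Vec:
    """Reference without a cache: recompute K/V for every position at every step."""
    return attend(proj_q(seq[-1]), [proj_k(x) for x in seq], [proj_v(x) for x in seq])

prompt = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
new_tokens = [[1.0, 1.0], [-1.0, 0.5]]   # stand-ins for embeddings of sampled tokens
cache = KVCache()
out = prefill(cache, prompt)
for x in new_tokens:
    out = decode_step(cache, x)
print(len(cache.keys), out == no_cache_last(prompt + new_tokens))  # 5 True
```

A real engine keeps one such cache per layer and per K/V head, which is what the memory formula above counts.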
What to say in an interview

Discuss the KV Cache memory footprint calculations. As batch size and sequence length grow, the cache size can grow to tens of gigabytes, making KV Cache the primary limiting factor for batch size and serving density.
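
A back-of-the-envelope calculator for the formula, assuming an illustrative configuration (80 layers, head_dim 128, FP16 cache, 4096-token context, batch 16) and comparing 64 K/V heads (MHA) with 8 (GQA).

```python
def kv_cache_bytes(layers: int, batch: int, seq_len: int, kv_heads: int,
                   head_dim: int, bytes_per_elem: int) -> int:
    # The leading 2 accounts for one K tensor and one V tensor per layer.
    return 2 * layers * batch * seq_len * kv_heads * head_dim * bytes_per_elem

for name, kv_heads in (("MHA, 64 KV heads", 64), ("GQA, 8 KV heads", 8)):
    total = kv_cache_bytes(80, 16, 4096, kv_heads, 128, 2)
    print(f"{name}: {total / 2**30:.0f} GiB")
```

With these assumed numbers the MHA cache is 160 GiB and the GQA cache 20 GiB, which is why the K/V head count is a first-order serving decision.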

Common misunderstanding

Thinking the KV Cache reduces memory usage. It does the opposite: it spends GPU memory (often gigabytes of stored key/value states) to avoid recomputing K/V for the whole prefix at every step. Each decode step still reads the entire cache from HBM, which is why decode tends to be memory-bandwidth-bound.

backends

Attention Backends: SDPA / FlashAttention / PagedAttention

Optimized backends rewrite memory access patterns to eliminate intermediate matrix bottlenecks, maximizing token generation speed.

What problem does this solve?

Standard attention writes the large S x S similarity score matrix to GPU HBM and reads it back for softmax and value multiplication. This memory-bandwidth bottleneck slows down processing as sequence length S grows.

Core mechanism

FlashAttention tiles Q, K, and V matrices into blocks, loads them into fast GPU SRAM, and computes online softmax statistics to avoid writing the S x S matrix to HBM. PagedAttention stores the KV Cache in fixed-size blocks (pages) addressed through a per-sequence block table, which largely eliminates memory fragmentation.

Name | Role | What it changes | Do not confuse it with
SDPA | Framework attention interface. | Dispatches scaled-dot-product attention to an eligible backend. | A promise that every call uses one particular kernel.
FlashAttention | Exact IO-aware attention algorithm and kernel family. | Tiles attention and avoids materializing the full score matrix in HBM. | An approximate or different model architecture.
PagedAttention | Serving-time KV Cache allocation strategy. | Maps logical cache blocks to physical blocks to reduce fragmentation. | A replacement attention formula or faster single kernel.
Naive attention:
  QK^T -> full S x S score matrix -> softmax -> V

FlashAttention intuition:
  for each Q block:
    stream K/V blocks
    update online softmax statistics
    accumulate output
  avoid writing the full S x S matrix to HBM

This is a simplified mental model. Kernel support still depends on dtype, mask/layout, hardware and library version.
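
A single-query sketch of the online-softmax recurrence in pure Python, assuming a hypothetical block_size and no real tiling, SRAM, or parallelism. It only demonstrates the arithmetic that lets FlashAttention stream K/V blocks while keeping a running max, a running denominator, and an output accumulator, and still match naive attention.

```python
import math

Vec = list[float]

def scores_for(q: Vec, keys: list[Vec]) -> list[float]:
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]

def naive_attention_row(q: Vec, keys: list[Vec], values: list[Vec]) -> Vec:
    """Materialize the full score row, softmax it, then mix values."""
    s = scores_for(q, keys)
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

def online_attention_row(q: Vec, keys: list[Vec], values: list[Vec], block_size: int = 2) -> Vec:
    """Stream K/V blocks, keeping only a running max, denominator, and accumulator."""
    m, denom = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        s = scores_for(q, keys[start:start + block_size])
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s))
        rescale = math.exp(m - m_new)   # shrink old stats to the new max (0.0 on the first block)
        p = [math.exp(x - m_new) for x in s]
        denom = denom * rescale + sum(p)
        acc = [a * rescale + sum(pi * v[c] for pi, v in zip(p, v_blk)) for c, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]     # normalize once at the end

q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [-2.0, 0.5, 3.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = naive_attention_row(q, keys, values)
for bs in (1, 2, 5):
    out = online_attention_row(q, keys, values, block_size=bs)
    print(bs, max(abs(a - b) for a, b in zip(out, ref)))
```

When a block raises the running max, the earlier denominator and accumulator are rescaled by exp(m_old - m_new), so no step ever needs the full row of scores at once.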

What to say in an interview

FlashAttention is an exact mathematical implementation of standard attention, not an approximation. It is an IO-aware algorithm that reduces memory access overhead by keeping intermediate scores in fast SRAM and recomputing them in the backward pass.

Common misunderstanding

Assuming PagedAttention speeds up individual kernel executions. PagedAttention is a system-level memory management technique that allows serving systems to allocate cache in page blocks, eliminating external fragmentation and enabling higher batch concurrency.
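
A toy block-table allocator in the spirit of PagedAttention, assuming hypothetical class and method names, a fixed block size, and a single shared free list; real serving engines add features such as block sharing between sequences and preemption.

```python
class PagedKVAllocator:
    """Hypothetical toy allocator: logical token slots -> fixed-size physical blocks."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size
        self.free: list[int] = list(range(num_blocks))   # shared pool of physical blocks
        self.tables: dict[int, list[int]] = {}           # seq_id -> block table
        self.lengths: dict[int, int] = {}                # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> tuple[int, int]:
        """Reserve the slot for one new token; return (physical_block, offset)."""
        n = self.lengths.get(seq_id, 0)
        table = self.tables.setdefault(seq_id, [])
        if n % self.block_size == 0:                     # first token, or last block is full
            if not self.free:
                raise RuntimeError("no free KV blocks; a scheduler would queue or preempt")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def release(self, seq_id: int) -> None:
        """Return a finished sequence's blocks to the shared pool."""
        self.free.extend(self.tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(num_blocks=4, block_size=16)
for _ in range(20):
    alloc.append_token(seq_id=0)   # 20 tokens -> 2 blocks, the second only partly used
for _ in range(5):
    alloc.append_token(seq_id=1)   # 5 tokens -> 1 block
print(alloc.tables, len(alloc.free))
```

Each sequence wastes at most block_size - 1 slots in its last block, and released blocks return to the shared pool immediately.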

quantization

Quantization and Inference Memory

Quantization compresses weights (and optionally activations or the KV Cache) to lower-precision formats, reducing memory footprint and often, but not always, speeding up inference.

What problem does this solve?

Frontier LLMs have hundreds of billions of parameters. In float16 each parameter takes 2 bytes, so the weights alone need hundreds of gigabytes of GPU memory. Quantizing parameters to 8-bit or 4-bit numbers shrinks the weight footprint, letting models fit on fewer GPUs and reducing the bytes read from memory per generated token.
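
Rough weight-memory arithmetic, assuming a 70B-parameter model for illustration and ignoring quantization metadata (scales, zero-points) and any layers kept in higher precision.

```python
def weight_gib(num_params: float, bits_per_weight: float) -> float:
    """Approximate weight storage in GiB (ignores scales, zero-points, unquantized layers)."""
    return num_params * bits_per_weight / 8 / 2**30

for bits in (16, 8, 4):
    print(f"70B params at {bits}-bit: about {weight_gib(70e9, bits):.0f} GiB")
```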

Core mechanism

Weight-only quantization (e.g. W4A16 schemes produced by methods such as GPTQ or AWQ) stores weights in 4-bit and dequantizes them to float16 inside the matrix multiplication. Weight-activation quantization (e.g. W8A8, with methods such as SmoothQuant) quantizes both operands to 8-bit so the matmul can run as integer operations directly on Tensor Cores.
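
A minimal round-to-nearest sketch of symmetric quantization with one scale per weight row, assuming made-up weight values. This is only the baseline these methods improve on: GPTQ adds error-compensating weight updates and AWQ rescales salient channels using activation statistics, and neither is shown here.

```python
def quantize_symmetric(row: list[float], bits: int = 8) -> tuple[list[int], float]:
    """Round-to-nearest symmetric quantization with one scale for the whole row."""
    qmax = 2 ** (bits - 1) - 1                       # 127 for INT8, 7 for INT4
    max_abs = max(abs(v) for v in row)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in row]
    return q, scale

def dequantize(q: list[int], scale: float) -> list[float]:
    """What a weight-only kernel does before (or fused into) the float16 matmul."""
    return [v * scale for v in q]

row = [0.12, -0.5, 0.33, 0.01]   # one output channel of a weight matrix (made-up values)
for bits in (8, 4):
    q, scale = quantize_symmetric(row, bits)
    err = max(abs(a - b) for a, b in zip(row, dequantize(q, scale)))
    print(f"INT{bits}: {q}, max error {err:.4f} (at most scale/2 = {scale / 2:.4f})")
```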

Quantization path | Quantized state | Likely benefit | Validation requirement
Weight-only (W4A16 / GPTQ / AWQ) | Weights; activations generally remain higher precision. | Reduces model footprint and decode weight bandwidth. | Measure dequantization overhead and output quality.
Weight-activation (W8A8 / SmoothQuant) | Weights and activations. | Enables lower-precision tensor math, often useful for prefill. | Confirm hardware kernels and activation outlier behavior.
KV-cache quantization | Stored keys and values during serving. | Increases context or concurrency capacity and reduces decode reads. | Measure quality, scale metadata, and backend support.
What to say in an interview

Explain the execution difference: weight-only quantization accelerates the memory-bandwidth-bound decode phase by reducing weight retrieval volume from HBM. Weight-activation quantization accelerates the compute-bound prefill phase by utilizing fast INT8 tensor math.

Common misunderstanding

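Assuming quantization always improves latency. Weight-only quantization must dequantize weights to float16 at runtime, which adds compute. In small-batch decode, where loading weights from HBM dominates, the bandwidth saving usually outweighs that cost; at large batch sizes or in compute-bound prefill, the extra dequantization work can erase the gain. The speedup only materializes when weight retrieval is the dominant bottleneck.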

Interview Practice

Use these representative prompts to rehearse mechanisms and tradeoffs. The full Q&A lives in the interview section so this handbook stays concept-first.

  • What are Q, K and V in attention?
  • Why does decoder-only attention need a causal mask?
  • Why does KV cache make autoregressive inference faster?
  • How do SDPA and FlashAttention relate?
Runtime Extensions

From model primitives to system stages

KV cache, attention backends, and decode behavior feed higher-level serving and practice tracks.

Multimodal Serving

See how autoregressive tokens interact with diffusion, VAE, audio, or video stages.

Coding Practice

Reuse attention, KV cache, backend, and quantization primitives in estimation drills.

Annotated Labs

Code reading curriculum

Each lab includes a starter file, key snippets, line-by-line explanation, common misunderstandings, and interview framing.

No. | Lab
01 | Tokenization and Embedding Lookup
02 | Q/K/V Shape Walkthrough
03 | Causal Mask and Attention Scores
04 | Multi-Head Attention Minimal Code
05 | RoPE Mental Model
06 | FFN / SwiGLU Parameter Count
07 | Training Loop and Cross Entropy Loss
08 | KV Cache Step-by-Step
09 | Attention Backend Comparison Reading
10 | Quantization Reading Lab

Open labs index

References

Official sources and high-quality intuition notes

Use official sources for factual checks and blogs only for supporting intuition.
