InfraLens

A clear starting point for learning AI infrastructure.

Overview

Lab 06: Triton Fused Softmax

Annotated code reading lab. Running code is optional.

Concept Goal

Read code to understand the concept

Kernel performance depends on data movement as much as math. Use memory hierarchy, tiling, fusion, coalescing, bank conflicts, and profiler counters to explain whether the workload is bandwidth-bound or compute-bound.
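As a rough illustration of the bandwidth-bound versus compute-bound question, the sketch below compares a kernel's arithmetic intensity (FLOPs per byte moved) with a device's ridge point (peak FLOP/s divided by peak bytes/s). The hardware numbers, the per-element FLOP count, and the function names are illustrative assumptions, not measurements or specifications.

```python
# Roofline-style classification. All numbers are illustrative assumptions.

def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte transferred to or from HBM."""
    return flops / bytes_moved


def classify(flops: float, bytes_moved: float,
             peak_flops: float, peak_bandwidth: float) -> str:
    """Compare arithmetic intensity with the device ridge point."""
    ridge = peak_flops / peak_bandwidth  # FLOPs per byte at the roofline knee
    if arithmetic_intensity(flops, bytes_moved) < ridge:
        return "bandwidth-bound"
    return "compute-bound"


# Hypothetical device: round placeholder numbers, not a real GPU spec.
PEAK_FLOPS = 20e12       # FP32 FLOP/s (assumed)
PEAK_BANDWIDTH = 1.5e12  # bytes/s (assumed)

rows, cols = 4096, 1024
elements = rows * cols
# Assumed cost model: about 5 simple ops per element (compare for max,
# subtract, exp, add for sum, divide), with exp counted as one op.
softmax_flops = 5 * elements
# A fused softmax reads each fp32 element once and writes it once.
softmax_bytes = 2 * 4 * elements

verdict = classify(softmax_flops, softmax_bytes, PEAK_FLOPS, PEAK_BANDWIDTH)
print(verdict, arithmetic_intensity(softmax_flops, softmax_bytes))
```

Under these assumptions softmax performs well under one FLOP per byte, far below the assumed ridge point of roughly 13 FLOPs per byte, which is why data movement rather than math tends to dominate.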

Mental Model

Core mechanism

  • Explain the problem, the mechanism, the resource tradeoff, the common failure mode, and the measurement that would validate the claim.
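One possible "measurement that would validate the claim" is achieved memory bandwidth: if a fused softmax runs close to the device's peak bandwidth, the bandwidth-bound explanation holds. The sketch below converts a kernel time into GB/s. The timing value, the peak bandwidth, and the helper names are hypothetical; a real number would come from a profiler or a benchmark run.

```python
def achieved_bandwidth_gbs(rows: int, cols: int,
                           bytes_per_element: int, elapsed_ms: float) -> float:
    """Effective GB/s, assuming one read and one write per element."""
    bytes_moved = 2 * rows * cols * bytes_per_element
    seconds = elapsed_ms / 1e3
    return bytes_moved / seconds / 1e9


def fraction_of_peak(achieved_gbs: float, peak_gbs: float) -> float:
    """Share of the device's peak bandwidth that the kernel reached."""
    return achieved_gbs / peak_gbs


# Hypothetical inputs: a 4096 x 1024 fp32 tensor and a made-up 0.05 ms timing.
elapsed_ms = 0.05
achieved = achieved_bandwidth_gbs(4096, 1024, 4, elapsed_ms)
share = fraction_of_peak(achieved, peak_gbs=1500.0)  # assumed peak, not a spec
print(f"{achieved:.1f} GB/s, {share:.0%} of assumed peak")
```

A low share would point to something other than raw bandwidth, for example launch overhead on small inputs or uncoalesced access.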

Starter files

Annotated starter links

These files are reading material first. If you later decide to run them, treat the run as optional validation rather than the main learning path.

Annotated Code Preview

Starter Preview

Excerpt from code/lab-06-triton-softmax/triton_softmax.py. This preview explains the key idea; the linked starter file is the source of truth.

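import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(x_ptr, y_ptr, n_cols: tl.constexpr, block_size: tl.constexpr):
    row_id = tl.program_id(0)  # one program instance per row
    offsets = tl.arange(0, block_size)
    mask = offsets < n_cols  # lanes past the end of the row are masked off
    # Row start is row_id * n_cols, which assumes a contiguous row-major tensor.
    row = tl.load(x_ptr + row_id * n_cols + offsets, mask=mask, other=-float("inf"))
    row = row - tl.max(row, axis=0)  # subtract the row max for numerical stability
    numerator = tl.exp(row)  # masked lanes hold -inf, so they contribute 0
    denominator = tl.sum(numerator, axis=0)
    tl.store(y_ptr + row_id * n_cols + offsets, numerator / denominator, mask=mask)

# Host side. The shape is an arbitrary example, not taken from the starter file.
rows, cols = 1024, 1000
x = torch.randn(rows, cols, device="cuda")
y = torch.empty_like(x)
block_size = triton.next_power_of_2(cols)  # one block must cover a whole row
baseline = torch.softmax(x, dim=-1)
softmax_kernel[(rows,)](x, y, cols, block_size=block_size)

Line-by-line Explanation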

Key code blocks

  • tl.program_id(0): selects which row this Triton program instance owns.
  • tl.arange / mask: creates the vector of column offsets and masks off lanes past the row length, so nothing is read or written out of bounds.
  • tl.load: loads the row's elements from HBM into the program. Masked lanes receive -inf, which becomes 0 after exp.
  • tl.max / tl.sum: perform the row-wise reductions inside the program.
  • tl.store: writes only the final softmax output, not every intermediate result.
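To make the steps above concrete without a GPU, here is a minimal pure-Python emulation of what one program instance does for one row. It is a reading aid under the assumption that the block covers the whole row; the function name is hypothetical and this is not Triton code.

```python
import math


def softmax_row_blocked(row, block_size):
    """Emulate one program instance: masked load, max, exp, sum, masked store."""
    n_cols = len(row)
    if block_size < n_cols:
        raise ValueError("block_size must cover the whole row")
    offsets = range(block_size)                       # like tl.arange
    mask = [offset < n_cols for offset in offsets]    # like offsets < n_cols
    # Masked load: lanes outside the row receive -inf (the `other` value).
    loaded = [row[o] if m else -math.inf for o, m in zip(offsets, mask)]
    row_max = max(loaded)                             # like tl.max
    # exp(-inf) is 0.0, so padding lanes add nothing to the sum.
    numerator = [math.exp(value - row_max) for value in loaded]
    denominator = sum(numerator)                      # like tl.sum
    # Masked store: only lanes inside the row are written back.
    return [num / denominator for num, m in zip(numerator, mask) if m]


probabilities = softmax_row_blocked([1000.0, 1001.0, 1002.0], block_size=4)
print(probabilities, sum(probabilities))
```

The example row uses large values on purpose: without the max subtraction, math.exp(1000.0) raises OverflowError, while the shifted version stays in range.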

What to Notice

How to read this code

  • Triton looks Pythonic, but it describes GPU kernels.
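  • The tutorial sizes the block with triton.next_power_of_2(n_cols) and masks the padding lanes. tl.arange needs a power-of-two length; covering the whole row with a single block is this kernel's design choice, not a general Triton requirement, and it limits how long a row the kernel can handle.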
  • Fusion helps when intermediate HBM writes or launch overhead matter.
  • A custom Triton kernel is still a maintenance responsibility.
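The block-size point can be checked with a few lines of plain Python. The helper below is a stand-in for triton.next_power_of_2, written here so the example runs without Triton; masked_lane_fraction is a hypothetical name for the share of lanes that do no useful work.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (stand-in for triton.next_power_of_2)."""
    if n < 1:
        raise ValueError("n must be positive")
    return 1 << (n - 1).bit_length()


def masked_lane_fraction(n_cols: int) -> float:
    """Fraction of block lanes that are padding and get masked off."""
    block_size = next_power_of_2(n_cols)
    return (block_size - n_cols) / block_size


for n_cols in (1000, 1024, 1025):
    print(n_cols, next_power_of_2(n_cols), round(masked_lane_fraction(n_cols), 3))
```

A row length just above a power of two wastes almost half the lanes, which is one reason block size is a per-kernel tuning decision.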

Common Misunderstandings

What this code does not mean

  • “Triton automatically optimizes everything.” You still choose block size, masks and memory layout.
  • “Fusion always wins.” Small or unsupported shapes may not benefit.
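Memory layout is one of the choices that stays with the author. The preview kernel computes a row's start as row_id * n_cols, which is only right when rows are packed back to back. The pure-Python sketch below, with a hypothetical padded buffer and helper name, shows what goes wrong when the row stride differs from n_cols.

```python
def row_from_flat(buffer, row_id, n_cols, row_stride):
    """Read one logical row out of a flat buffer using an explicit row stride."""
    start = row_id * row_stride
    return buffer[start:start + n_cols]


n_cols, row_stride = 4, 6  # each row holds 4 values followed by 2 padding slots
# Flat buffer for 3 rows; -1 marks padding that is not part of any row.
buffer = [(10 * r + c) if c < n_cols else -1
          for r in range(3) for c in range(row_stride)]

correct = row_from_flat(buffer, 1, n_cols, row_stride)  # uses the real stride
wrong = row_from_flat(buffer, 1, n_cols, n_cols)        # assumes packed rows
print(correct, wrong)
```

The official Triton tutorial kernel takes row strides as kernel arguments for this reason; whether the starter file does the same should be checked in the linked source.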

Interview Explanation

How to say it out loud

Triton fused softmax assigns a row/block to each program. It loads the row, computes stable softmax reductions inside the kernel, and stores only the final output. This reduces intermediate HBM writes compared with composing many separate operations.
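The traffic claim can be made quantitative with a simple count of elements moved to and from HBM. The sketch assumes the unfused version runs five separate operations (row max, subtract, exp, row sum, divide), each materializing its result in HBM with no cache reuse; that breakdown and the function names are assumptions for illustration.

```python
def unfused_elements(m: int, n: int) -> int:
    """Elements read plus written when each step is a separate kernel."""
    steps = [
        (m * n, m),          # row max: read x, write one max per row
        (m * n + m, m * n),  # subtract: read x and maxes, write z
        (m * n, m * n),      # exp: read z, write numerator
        (m * n, m),          # row sum: read numerator, write one sum per row
        (m * n + m, m * n),  # divide: read numerator and sums, write output
    ]
    reads = sum(read for read, _ in steps)
    writes = sum(write for _, write in steps)
    return reads + writes


def fused_elements(m: int, n: int) -> int:
    """Elements moved by a fused kernel: read the input once, write the output once."""
    return 2 * m * n


m, n = 4096, 1024
ratio = unfused_elements(m, n) / fused_elements(m, n)
print(unfused_elements(m, n), fused_elements(m, n), round(ratio, 2))
```

Under this model the unfused version moves about four times as many elements, which is the upper bound on the speedup to expect if the kernel is purely bandwidth-bound.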

External intuition notes

Additional intuition

  • The Triton fused softmax tutorial is the main source for this lab: each program loads, normalizes and stores rows while avoiding separate intermediate tensors. Official: Triton fused softmax tutorial
  • The Triton documentation is the source of truth for the programming model: Triton looks Pythonic, but it exposes block-level parallel operations rather than ordinary Python loops. Official: Triton documentation
  • Community softmax writeups are useful for intuition, but correctness and API facts should stay grounded in Triton docs. Blog: Softmax in OpenAI Triton

Further Reading

Official, paper and practical references