Lab 06: Triton Fused Softmax
Annotated code reading lab. Running code is optional.
GPU Kernels
Kernel performance depends on data movement as much as math. Use memory hierarchy, tiling, fusion, coalescing, bank conflicts, and profiler counters to explain whether the workload is bandwidth-bound or compute-bound.
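A minimal roofline sketch of that classification follows. It assumes the kernel's FLOP count and HBM bytes are already known. The peak numbers are hypothetical placeholders, not the specs of any real GPU, so substitute values for your hardware. Arithmetic intensity is FLOPs per byte moved. If it falls below the ridge point (peak FLOP/s divided by peak bandwidth), the kernel is bandwidth-bound.

```python
from dataclasses import dataclass


@dataclass
class Hardware:
    peak_flops: float      # FLOP/s (hypothetical placeholder)
    peak_bandwidth: float  # HBM bytes/s (hypothetical placeholder)

    @property
    def ridge_point(self) -> float:
        # Arithmetic intensity (FLOP/byte) where the memory roof meets the compute roof.
        return self.peak_flops / self.peak_bandwidth


def classify(flops: float, bytes_moved: float, hw: Hardware) -> tuple[str, float]:
    """Return (bound, attainable FLOP/s) under a simple roofline model."""
    intensity = flops / bytes_moved
    attainable = min(hw.peak_flops, intensity * hw.peak_bandwidth)
    bound = "bandwidth-bound" if intensity < hw.ridge_point else "compute-bound"
    return bound, attainable


# Hypothetical accelerator: 50 TFLOP/s and 1.5 TB/s, so the ridge point is about 33 FLOP/byte.
hw = Hardware(peak_flops=50e12, peak_bandwidth=1.5e12)

# Fused fp32 softmax on a 4096 x 1024 tensor: each element is read once and written once (8 bytes).
# About 5 FLOPs per element (max, subtract, exp, sum, divide) is a coarse assumption.
m, n = 4096, 1024
print(classify(flops=5 * m * n, bytes_moved=8 * m * n, hw=hw))
```

Even if the FLOP estimate per element were several times larger, softmax would stay far below the ridge point. That is why it is usually treated as bandwidth-bound.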
Read code to understand the concept
Core mechanism
- Explain the problem, the mechanism, the resource tradeoff, the common failure mode, and the measurement that would validate the claim.
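For softmax, the data-movement side of that argument can be counted directly. The sketch below assumes an M x N input and one full HBM pass per separate elementwise or reduction op. It tallies element reads and writes for an unfused sequence of row max, subtract, exp, row sum, and divide (the decomposition used in the Triton fused softmax tutorial). It compares that against a fused kernel that reads each element once and writes each output once. The accounting ignores cache reuse, so treat the ratio as an idealized estimate rather than a measurement.

```python
def unfused_traffic(m: int, n: int) -> tuple[int, int]:
    """Element reads/writes for softmax built from five separate ops on an (m, n) input.

    Assumes every op streams its inputs from HBM and writes its output back (no cache reuse).
    """
    ops = [
        # (op, elements read, elements written)
        ("row max", m * n, m),           # x -> x_max
        ("subtract", m * n + m, m * n),  # x, x_max -> z
        ("exp", m * n, m * n),           # z -> numerator
        ("row sum", m * n, m),           # numerator -> denominator
        ("divide", m * n + m, m * n),    # numerator, denominator -> y
    ]
    reads = sum(r for _, r, _ in ops)
    writes = sum(w for _, _, w in ops)
    return reads, writes


def fused_traffic(m: int, n: int) -> tuple[int, int]:
    """A fused kernel reads x once and writes y once; intermediates stay on chip."""
    return m * n, m * n


m, n = 4096, 1024
u_reads, u_writes = unfused_traffic(m, n)
f_reads, f_writes = fused_traffic(m, n)
print(f"unfused: {u_reads} reads, {u_writes} writes")  # 5MN + 2M reads, 3MN + 2M writes
print(f"fused:   {f_reads} reads, {f_writes} writes")
print(f"idealized traffic ratio: {(u_reads + u_writes) / (f_reads + f_writes):.3f}x")
```

The ratio works out to (8MN + 4M) / 2MN = 4 + 2/N, so roughly 4x less traffic. Profiler counters for DRAM bytes read and written are the measurement that would confirm or refute this on real hardware.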
Annotated starter links
These files are reading material first. If you later decide to run them, treat the run as optional validation rather than the main learning path.
Starter Preview
Excerpt from code/lab-06-triton-softmax/triton_softmax.py. This preview explains the key idea; the linked starter file is the source of truth.
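```python
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(x_ptr, y_ptr, n_cols: tl.constexpr, block_size: tl.constexpr):
    row_id = tl.program_id(0)  # one program instance per row
    offsets = tl.arange(0, block_size)  # block_size: a power of two >= n_cols
    mask = offsets < n_cols  # ignore padded lanes past the end of the row
    # Assumes a contiguous (rows, n_cols) tensor; padded lanes load -inf so exp() maps them to 0.
    row = tl.load(x_ptr + row_id * n_cols + offsets, mask=mask, other=-float("inf"))
    row = row - tl.max(row, axis=0)  # subtract the row max for numerical stability
    numerator = tl.exp(row)
    denominator = tl.sum(numerator, axis=0)
    tl.store(y_ptr + row_id * n_cols + offsets, numerator / denominator, mask=mask)


# Hypothetical driver: the tensor shape is chosen only for illustration.
x = torch.randn(1823, 781, device="cuda")
rows, cols = x.shape
y = torch.empty_like(x)
block_size = triton.next_power_of_2(cols)
baseline = torch.softmax(x, dim=-1)
softmax_kernel[(rows,)](x, y, cols, block_size=block_size)
assert torch.allclose(y, baseline, atol=1e-6)
```
Key code blocks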
- tl.program_id(0): selects which row this Triton program instance owns. The launch grid (rows,) creates one program per row.
- tl.arange / mask: creates vector offsets and masks off lanes past the row length.
- tl.load: loads the row's elements from HBM into on-chip storage for this program.
- tl.max / tl.sum: perform the max and sum reductions inside the program, without writing intermediate tensors to HBM.
- tl.store: writes only the final softmax output, not every intermediate result.
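To make the per-program logic concrete without a GPU, here is a pure-Python emulation of what one program instance computes. It is a sketch, not Triton semantics. The hypothetical helpers emulate_program and emulate_launch stand in for one program and for the launch grid. A Python list plays the role of the loaded block.

```python
import math


def emulate_program(x: list[float], row_id: int, n_cols: int, block_size: int) -> list[float]:
    """CPU stand-in for one softmax_kernel program; x is a flat, row-major buffer."""
    offsets = range(block_size)           # like tl.arange(0, block_size)
    mask = [o < n_cols for o in offsets]  # like offsets < n_cols
    base = row_id * n_cols
    # Masked lanes read -inf (like other=-float("inf")); without the mask they would
    # read the next row or run past the end of the buffer.
    row = [x[base + o] if keep else -math.inf for o, keep in zip(offsets, mask)]
    row_max = max(row)                                     # like tl.max(row, axis=0)
    numerator = [math.exp(v - row_max) for v in row]       # exp(-inf) is 0.0 for padded lanes
    denominator = sum(numerator)                           # like tl.sum(numerator, axis=0)
    # Like the masked tl.store: only the first n_cols lanes are written back.
    return [v / denominator for v, keep in zip(numerator, mask) if keep]


def emulate_launch(x: list[float], rows: int, n_cols: int, block_size: int) -> list[list[float]]:
    # One "program" per row, like softmax_kernel[(rows,)].
    return [emulate_program(x, r, n_cols, block_size) for r in range(rows)]


# Two rows of three columns, padded to a block of four lanes.
# The first row would overflow math.exp without the max subtraction.
x = [1000.0, 1000.0, 1000.0,
     1.0, 2.0, 3.0]
y = emulate_launch(x, rows=2, n_cols=3, block_size=4)
print(y[0])
print([round(v, 4) for v in y[1]])
```

Calling math.exp(1000.0) directly raises OverflowError in Python. Subtracting the row max keeps every exponent at or below zero, which is the same reason the kernel calls tl.max before tl.exp.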
How to read this code
- Triton looks Pythonic, but it describes GPU kernels.
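- Some tutorials size the block with triton.next_power_of_2(n_cols) and rely on the mask to handle the padding. Covering a whole row with one block is a design choice of this kernel, not a universal Triton requirement. The block length itself must be a power of two, because tl.arange expects a power-of-two range.
- Fusion helps when intermediate HBM writes or launch overhead matter.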
- A custom Triton kernel is still a maintenance responsibility.
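The power-of-two block sizing for a one-block-per-row softmax can be sketched in plain Python. The helper next_pow2 is a hypothetical stand-in for triton.next_power_of_2 (the smallest power of two that is at least n). The sketch reports how much of each block is masked padding for a few row lengths. Those masked lanes are still scheduled even though they do no useful work.

```python
def next_pow2(n: int) -> int:
    """Smallest power of two >= n (n >= 1); hypothetical stand-in for triton.next_power_of_2."""
    return 1 << (n - 1).bit_length()


def padding_report(n_cols: int) -> tuple[int, float]:
    """Block size for a one-block-per-row softmax and the fraction of lanes masked off."""
    block_size = next_pow2(n_cols)
    masked_fraction = (block_size - n_cols) / block_size
    return block_size, masked_fraction


for n_cols in (781, 1024, 1025, 4000):
    block_size, masked = padding_report(n_cols)
    print(f"n_cols={n_cols:5d} -> block_size={block_size:5d}, masked lanes={masked:.1%}")
```

Going from 1024 to 1025 columns doubles the block to 2048 lanes, nearly half of them masked. This is one reason block size stays a per-kernel decision rather than something Triton picks for you.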
What this code does not mean
- “Triton automatically optimizes everything.” You still choose block size, masks and memory layout.
- “Fusion always wins.” Small or unsupported shapes may not benefit.
How to say it out loud
Triton fused softmax assigns a row/block to each program. It loads the row, computes stable softmax reductions inside the kernel, and stores only the final output. This reduces intermediate HBM writes compared with composing many separate operations.
Additional intuition
- The Triton fused softmax tutorial is the main source for this lab: each program loads, normalizes and stores rows while avoiding separate intermediate tensors. Official: Triton fused softmax tutorial
- The Triton documentation is the source of truth for the programming model: Triton looks Pythonic, but it exposes block-level parallel operations rather than ordinary Python loops. Official: Triton documentation
- Community softmax writeups are useful for intuition, but correctness and API facts should stay grounded in Triton docs. Blog: Softmax in OpenAI Triton
