#!/usr/bin/env python3
"""Educational Triton fused softmax starter."""

from __future__ import annotations

import argparse
import sys


def _int_at_least(minimum: int):
    """Build an argparse ``type`` callable that accepts integers >= *minimum*."""

    def convert(text: str) -> int:
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value

    return convert


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for the softmax demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what a no-argument call
            has always done.

    Returns:
        A namespace with integer attributes ``rows``, ``cols``, ``warmup`` and
        ``iters``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when an option is not
            an integer or falls below its minimum.
    """
    parser = argparse.ArgumentParser(description="Triton fused softmax starter")
    # rows/cols/iters must be positive: an empty matrix has no max error to
    # report, and the benchmark divides the elapsed time by ``iters``.
    parser.add_argument("--rows", type=_int_at_least(1), default=1024,
                        help="number of matrix rows (>= 1)")
    parser.add_argument("--cols", type=_int_at_least(1), default=512,
                        help="number of matrix columns (>= 1)")
    parser.add_argument("--warmup", type=_int_at_least(0), default=10,
                        help="untimed warm-up runs before benchmarking (>= 0)")
    parser.add_argument("--iters", type=_int_at_least(1), default=50,
                        help="timed benchmark iterations (>= 1)")
    return parser.parse_args(argv)


def main() -> None:
    """Run the fused-softmax demo: verify against PyTorch, then time both.

    The function parses the command line, checks that PyTorch, Triton and a
    CUDA device are usable, launches a one-program-per-row Triton softmax
    kernel on a random float32 matrix, and prints ``max_error`` (largest
    absolute difference from ``torch.softmax``) followed by the average
    per-call time in milliseconds of each implementation.

    Every unmet precondition prints an explanatory message and returns
    normally instead of raising.
    """
    # Parse first so that --help and argument errors are reported even on
    # machines without PyTorch, Triton, or a GPU.
    args = parse_args()

    try:
        import torch
    except ImportError as exc:
        print(f"PyTorch is required for this lab: {exc}")
        return

    try:
        import triton
        import triton.language as tl
    except ImportError:
        print("Triton is not installed. Install Triton in a CUDA-capable environment to run this lab.")
        return

    if not torch.cuda.is_available():
        print("CUDA is not available. Triton GPU kernels cannot run in this environment.")
        return

    # Degenerate sizes would otherwise fail further down: an empty matrix has
    # no maximum error to report, and bench() divides by args.iters.
    if args.rows <= 0 or args.cols <= 0 or args.iters <= 0:
        print("rows, cols and iters must all be positive integers.")
        return

    # Reject oversized rows before allocating anything on the GPU. The kernel
    # holds a whole (padded) row in a single program, so the row length is
    # capped.
    block_size = triton.next_power_of_2(args.cols)
    if block_size > 131072:
        print("Column count is too large for this simple starter kernel.")
        return

    @triton.jit
    def softmax_kernel(x_ptr, y_ptr, n_cols: tl.constexpr, block_size: tl.constexpr):
        # A Triton "program" is closer to a CUDA thread block than to a Python
        # function call. Here program_id(0) selects one row of the matrix.
        row_id = tl.program_id(0)

        # Widen the row index before multiplying: in 32-bit arithmetic
        # row_id * n_cols wraps once the matrix has 2**31 or more elements.
        row_start = row_id.to(tl.int64) * n_cols

        # offsets is a vector of column positions handled by this program.
        # block_size is a compile-time constant so Triton can generate a fixed
        # vectorized kernel for the row.
        offsets = tl.arange(0, block_size)
        mask = offsets < n_cols

        # tl.load reads a full row into SRAM/register-backed program values.
        # Masked lanes use -inf so they do not affect max/sum for non power-of-2
        # column counts.
        row = tl.load(x_ptr + row_start + offsets, mask=mask, other=-float("inf"))

        # max, exp, sum and divide are fused inside one kernel. The teaching
        # point is not that this exact kernel is always faster, but that fusion
        # avoids writing intermediate tensors back to HBM between operations.
        row = row - tl.max(row, axis=0)
        numerator = tl.exp(row)
        denominator = tl.sum(numerator, axis=0)
        out = numerator / denominator
        tl.store(y_ptr + row_start + offsets, out, mask=mask)

    device = "cuda"
    x = torch.randn(args.rows, args.cols, device=device, dtype=torch.float32)
    y = torch.empty_like(x)
    # The PyTorch result is a semantic baseline: the Triton kernel should
    # compute the same softmax, only with a different kernel schedule.
    baseline = torch.softmax(x, dim=-1)

    # One program per row. If rows=1024, Triton launches 1024 independent
    # program instances, each running the vectorized row softmax above.
    grid = (args.rows,)

    def launch_triton():
        # The kernel indexes rows as row * n_cols, i.e. it relies on x and y
        # being the contiguous tensors allocated above.
        softmax_kernel[grid](x, y, args.cols, block_size=block_size)

    launch_triton()
    torch.cuda.synchronize()
    max_error = (baseline - y).abs().max().item()
    print(f"max_error={max_error:.6e}")

    def bench(fn):
        """Return the average time of one ``fn()`` call in milliseconds."""
        # Warm-up calls keep one-time costs (such as the first-launch kernel
        # compilation) out of the timed region.
        for _ in range(args.warmup):
            fn()
        torch.cuda.synchronize()
        # CUDA events bracket the timed loop on the GPU; the synchronize after
        # end.record() ensures both events have completed before reading them.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(args.iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / args.iters

    torch_ms = bench(lambda: torch.softmax(x, dim=-1))
    triton_ms = bench(launch_triton)
    print(f"torch_softmax_ms={torch_ms:.4f}")
    print(f"triton_softmax_ms={triton_ms:.4f}")
    print("Fusion intuition: max, exp, sum and divide stay inside one kernel instead of materializing intermediate tensors.")


# Run the demo only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
