#!/usr/bin/env python3
"""Educational FlashAttention online-softmax walkthrough.

This is intentionally small and readable, not a fast implementation.
"""

from __future__ import annotations

import argparse
import math


def parse_args():
    parser = argparse.ArgumentParser(description="FlashAttention mental model starter")
    parser.add_argument("--seq-len", type=int, default=8)
    parser.add_argument("--head-dim", type=int, default=4)
    parser.add_argument("--block-size", type=int, default=4)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


def main() -> None:
    """Run naive and blockwise attention on random data and report the gap.

    Random Q/K/V tensors are built from the command-line settings. Attention
    is computed once by materializing the full score and probability matrices
    and once by streaming over K/V blocks with an online softmax. The sizes of
    the naive intermediates and the largest absolute difference between the
    two outputs are printed.
    """
    # Import lazily so a missing PyTorch install produces a short message
    # instead of a traceback.
    try:
        import torch
    except ImportError as exc:
        print(f"PyTorch is required for this lab: {exc}")
        return

    args = parse_args()
    torch.manual_seed(args.seed)

    seq_len, head_dim, block_size = args.seq_len, args.head_dim, args.block_size
    q, k, v = (torch.randn(seq_len, head_dim) for _ in range(3))
    scale = 1.0 / math.sqrt(head_dim)

    # Reference path: build the whole S x S score matrix, then the whole
    # S x S probability matrix. These are the intermediates FlashAttention
    # avoids keeping in HBM for long sequences.
    scores = (q @ k.T) * scale
    probs = scores.softmax(dim=-1)
    naive_out = probs @ v

    def attend_row(query):
        """Return the attention output for one (1, head_dim) query row."""
        # Per-row state: the largest score seen so far (for a numerically
        # stable softmax), the softmax denominator relative to that maximum,
        # and the already-normalized output.
        running_max = torch.tensor(float("-inf"))
        running_denom = torch.tensor(0.0)
        acc = torch.zeros(head_dim)

        # Walk K/V one block at a time; a complete row of attention
        # probabilities is never stored.
        for lo in range(0, seq_len, block_size):
            hi = min(lo + block_size, seq_len)
            k_block, v_block = k[lo:hi], v[lo:hi]
            block_scores = (query @ k_block.T).squeeze(0) * scale

            # A larger score in this block shifts the reference maximum, so
            # everything accumulated earlier is scaled by exp(old - new).
            new_max = torch.maximum(running_max, block_scores.max())
            rescale = torch.exp(running_max - new_max)
            weights = torch.exp(block_scores - new_max)
            new_denom = running_denom * rescale + weights.sum()

            # Undo the previous normalization, apply the rescale, add this
            # block's weighted values, and renormalize. Real FlashAttention
            # kernels apply the same recurrence with tiled Q/K/V and
            # carefully managed on-chip memory.
            weighted_values = weights @ v_block
            acc = (acc * running_denom * rescale + weighted_values) / new_denom

            running_max, running_denom = new_max, new_denom

        return acc

    blockwise_out = torch.stack(
        [attend_row(q[i : i + 1]) for i in range(seq_len)], dim=0
    )
    max_error = torch.max(torch.abs(naive_out - blockwise_out)).item()

    def byte_size(tensor):
        """Return the number of bytes occupied by the tensor's elements."""
        return tensor.numel() * tensor.element_size()

    score_bytes = byte_size(scores)
    prob_bytes = byte_size(probs)
    print(f"seq_len={seq_len} head_dim={head_dim} block_size={block_size}")
    print(f"naive_score_matrix={scores.shape} bytes={score_bytes}")
    print(f"naive_probability_matrix={probs.shape} bytes={prob_bytes}")
    print(f"max_error={max_error:.6e}")
    print("Online softmax avoids keeping the full S x S probability matrix alive at once.")


# Run the walkthrough only when this file is executed as a script; importing
# the module does not trigger it.
if __name__ == "__main__":
    main()
