#!/usr/bin/env python3
"""Minimal single-GPU/CPU training loop with memory logging."""

from __future__ import annotations

import argparse
import time
from pathlib import Path


def parse_args():
    parser = argparse.ArgumentParser(description="Single GPU training loop starter")
    parser.add_argument("--steps", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--input-dim", type=int, default=64)
    parser.add_argument("--hidden-dim", type=int, default=128)
    parser.add_argument("--num-classes", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--checkpoint", type=Path, default=Path("checkpoint_single_gpu.pt"))
    return parser.parse_args()


def cuda_memory_line(torch) -> str:
    """Summarise CUDA caching-allocator memory usage on a single line.

    Args:
        torch: The imported ``torch`` module. It is passed in because this file
            imports PyTorch lazily inside ``main``.

    Returns:
        ``"cuda_memory=not_available"`` when CUDA is unavailable; otherwise
        ``"allocated=<x>MB reserved=<y>MB max=<z>MB"``. The values are decimal
        megabytes (bytes / 1e6) with one decimal place.
    """
    cuda = torch.cuda
    if cuda.is_available():
        bytes_per_mb = 1e6
        # allocated: bytes currently held by live tensors.
        # reserved: bytes the caching allocator keeps for reuse. This stays high
        #   after tensors die, which is normal and explains why nvidia-smi often
        #   shows more than "allocated".
        # max: peak allocated bytes since the last reset_peak_memory_stats().
        readings = (
            ("allocated", cuda.memory_allocated()),
            ("reserved", cuda.memory_reserved()),
            ("max", cuda.max_memory_allocated()),
        )
        return " ".join(f"{label}={amount / bytes_per_mb:.1f}MB" for label, amount in readings)
    return "cuda_memory=not_available"


def main() -> None:
    """Train a tiny MLP on synthetic data, logging loss, step time and CUDA memory.

    The function:

    * parses the CLI options;
    * runs ``--steps`` AdamW optimizer steps on random batches, printing one
      line per step;
    * saves a resumable checkpoint (model state, optimizer state, step count
      and options) to ``--checkpoint``.

    If PyTorch is not installed, it prints a message and returns without
    training.
    """
    # Parse before the (slow) torch import so --help and argument errors are
    # reported immediately, even on machines without PyTorch.
    args = parse_args()

    try:
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
    except ImportError as exc:
        print(f"PyTorch is required for this lab: {exc}")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A tiny MLP is enough to expose the training-loop lifecycle. Real
    # Transformer blocks have attention/FFN modules, but autograd, gradients and
    # optimizer states appear in the same phases.
    model = nn.Sequential(
        nn.Linear(args.input_dim, args.hidden_dim),
        nn.GELU(),
        nn.Linear(args.hidden_dim, args.num_classes),
    ).to(device)

    # AdamW is intentionally used because it creates optimizer states. This
    # mirrors large-model training, where optimizer state can exceed parameter
    # memory and motivates ZeRO/FSDP sharding.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    print(f"device={device}")
    for step in range(args.steps):
        # Reset the peak counter so the logged "max" is this step's peak only.
        if device.type == "cuda":
            torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()

        # Synthetic data keeps the lab focused on system mechanics. The input is
        # a batch of feature vectors; the output of the model is class logits.
        x = torch.randn(args.batch_size, args.input_dim, device=device)
        y = torch.randint(0, args.num_classes, (args.batch_size,), device=device)

        # PyTorch accumulates gradients by default. Clearing them here means
        # each loop iteration represents one independent optimizer step. If you
        # remove this line, you are doing gradient accumulation.
        optimizer.zero_grad(set_to_none=True)

        # Forward creates logits and saves activation tensors needed by backward.
        # This is where activation memory comes from; checkpointing trades extra
        # recomputation for keeping fewer of these tensors.
        logits = model(x)
        loss = F.cross_entropy(logits, y)

        # Backward traverses the autograd graph and fills parameter.grad tensors.
        # In a DDP model, this same call also triggers gradient synchronization
        # hooks when gradient buckets become ready.
        loss.backward()

        # AdamW reads parameters and gradients, then creates/updates optimizer
        # state such as moving averages before writing updated parameters.
        optimizer.step()

        # CUDA kernels launch asynchronously; wait for them so the wall-clock
        # time covers the GPU work rather than just the kernel launches.
        if device.type == "cuda":
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"step={step:03d} loss={loss.item():.4f} time_ms={elapsed_ms:.2f} {cuda_memory_line(torch)}")

    # torch.load defaults to weights_only=True in recent PyTorch releases, and
    # that mode refuses to unpickle arbitrary classes such as pathlib.PosixPath.
    # Store Path-valued options as plain strings so the checkpoint loads with
    # default settings.
    saved_args = {
        key: str(value) if isinstance(value, Path) else value
        for key, value in vars(args).items()
    }

    # torch.save does not create missing directories (e.g. --checkpoint runs/a/ck.pt).
    args.checkpoint.parent.mkdir(parents=True, exist_ok=True)

    # A training checkpoint needs optimizer state if you want to resume the
    # training trajectory, not just load the final weights for inference.
    # Tensors are saved on `device`; pass map_location when loading them on a
    # machine without that device.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "steps": args.steps,
            "args": saved_args,
        },
        args.checkpoint,
    )
    print(f"saved checkpoint: {args.checkpoint}")


# Run the training loop only when this file is executed as a script, so that
# importing it (e.g. to reuse parse_args or cuda_memory_line) has no side effects.
if __name__ == "__main__":
    main()
