#!/usr/bin/env python3
"""Minimal DDP training starter with synthetic data."""

from __future__ import annotations

import argparse
import os
import time
from pathlib import Path


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the DDP starter.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what a torchrun-launched worker
            receives. Passing a list lets callers/tests bypass ``sys.argv``.

    Returns:
        Namespace with ``steps``, ``batch_size``, ``input_dim``,
        ``hidden_dim``, ``num_classes`` (ints) and ``checkpoint`` (a ``Path``).

    Exits via ``parser.error`` (status 2) when a value is out of range, so a
    bad size fails immediately at startup rather than later inside tensor
    construction with a much less readable error.
    """
    parser = argparse.ArgumentParser(description="DDP conversion starter")
    parser.add_argument("--steps", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--input-dim", type=int, default=64)
    parser.add_argument("--hidden-dim", type=int, default=128)
    parser.add_argument("--num-classes", type=int, default=10)
    parser.add_argument("--checkpoint", type=Path, default=Path("checkpoint_ddp.pt"))
    args = parser.parse_args(argv)

    # Zero steps is legal (it just saves the freshly initialised model), but a
    # negative count is almost certainly a typo.
    if args.steps < 0:
        parser.error("--steps must be >= 0")
    # Every tensor dimension must be at least 1 for the model and data to exist.
    for option in ("batch_size", "input_dim", "hidden_dim", "num_classes"):
        if getattr(args, option) < 1:
            parser.error(f"--{option.replace('_', '-')} must be >= 1")
    return args


def setup_distributed(torch):
    """Join the default process group when torchrun launched several processes.

    torchrun exports RANK (global process id), WORLD_SIZE (number of processes
    in the job) and LOCAL_RANK (process index on this node, typically used to
    pick this process's GPU) before starting the script. When those variables
    are absent, the defaults describe a single-process run.

    Args:
        torch: The already-imported ``torch`` module.

    Returns:
        Tuple ``(distributed, rank, world_size, local_rank)`` where
        ``distributed`` is True only when WORLD_SIZE is greater than 1.
    """
    import torch.distributed as dist

    env = os.environ
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))

    multi_process = world_size > 1
    # NCCL is the GPU collective backend; gloo is the CPU fallback.
    backend = "gloo"
    if torch.cuda.is_available():
        backend = "nccl"

    # Joining the group is what lets DDP run its gradient all-reduce
    # collectives; skip it for single-process runs or if already joined.
    must_join = multi_process and not dist.is_initialized()
    if must_join:
        dist.init_process_group(backend=backend)

    return multi_process, rank, world_size, local_rank


def main() -> None:
    """Train a tiny MLP on synthetic data, using DDP when launched by torchrun.

    Single process (plain ``python``): trains on one device with no process
    group. Multi-process (torchrun with WORLD_SIZE > 1): each rank joins the
    process group, wraps the model in DDP, and trains on its own synthetic data
    stream. Rank 0 prints progress and writes the checkpoint to
    ``--checkpoint``.

    Returns early, after printing a message, if PyTorch is not installed.
    """
    try:
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel as DDP
    except ImportError as exc:
        print(f"PyTorch is required for this lab: {exc}")
        return

    args = parse_args()
    distributed, rank, world_size, local_rank = setup_distributed(torch)

    # Everything after joining the process group runs inside try/finally so the
    # group is torn down even when training or checkpointing raises.
    try:
        if torch.cuda.is_available():
            visible = torch.cuda.device_count()
            if local_rank >= visible:
                # Launching more ranks per node than GPUs would otherwise fail
                # inside set_device with a much less readable CUDA error.
                raise RuntimeError(
                    f"LOCAL_RANK={local_rank} but only {visible} CUDA device(s) are visible"
                )
            # One process normally owns one local GPU. Binding the process
            # explicitly prevents multiple ranks from accidentally using cuda:0.
            torch.cuda.set_device(local_rank)
            device = torch.device("cuda", local_rank)
        else:
            device = torch.device("cpu")

        model = nn.Sequential(
            nn.Linear(args.input_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.num_classes),
        ).to(device)

        if distributed:
            # DDP wraps the module and registers autograd hooks on parameters.
            # The forward call still looks local, but backward will launch
            # communication when gradient buckets become ready.
            ddp_kwargs = {"device_ids": [local_rank]} if device.type == "cuda" else {}
            model = DDP(model, **ddp_kwargs)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

        # In a real dataset you would usually use DistributedSampler. Here the
        # rank offset makes the synthetic stream different on each rank, which
        # is enough to show that each process owns a data shard.
        generator = torch.Generator(device=device)
        generator.manual_seed(1234 + rank)

        if rank == 0:
            # Report the backend the process group actually uses; a
            # single-process run has no process group, hence "none".
            backend = dist.get_backend() if distributed else "none"
            print(f"distributed={distributed} world_size={world_size} backend={backend}")

        for step in range(args.steps):
            start = time.perf_counter()
            x = torch.randn(args.batch_size, args.input_dim, device=device, generator=generator)
            y = torch.randint(0, args.num_classes, (args.batch_size,), device=device, generator=generator)

            optimizer.zero_grad(set_to_none=True)
            logits = model(x)
            loss = F.cross_entropy(logits, y)

            # DDP installs autograd hooks. During backward, gradient buckets
            # are all-reduced so every rank applies the same averaged gradients.
            loss.backward()
            optimizer.step()

            # Wait for queued GPU work so the wall-clock timing below is real.
            if device.type == "cuda":
                torch.cuda.synchronize()
            if rank == 0:
                # NOTE: this is rank 0's loss on its own data shard, not an
                # average across ranks (only gradients are all-reduced).
                print(f"step={step:03d} loss={loss.item():.4f} time_ms={(time.perf_counter() - start) * 1000:.2f}")

        if rank == 0:
            # Only one rank should write shared files. Without this guard, all
            # ranks could race to write the same checkpoint path.
            raw_model = model.module if distributed else model
            # Create the destination directory so a nested --checkpoint works.
            args.checkpoint.parent.mkdir(parents=True, exist_ok=True)
            torch.save({"model": raw_model.state_dict(), "optimizer": optimizer.state_dict()}, args.checkpoint)
            print(f"saved checkpoint: {args.checkpoint}")
    finally:
        if distributed and dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    # Script entry point; when launched with torchrun, every worker process runs main().
    main()
