#!/usr/bin/env python3
"""ZeRO/FSDP-style memory sharding estimates."""

from __future__ import annotations

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="ZeRO/FSDP memory accounting starter")
    parser.add_argument("--params-b", type=float, default=7.0, help="Parameter count in billions")
    parser.add_argument("--precision-bytes", type=float, default=2.0, help="Bytes for parameters and gradients")
    parser.add_argument("--adam-bytes", type=float, default=8.0, help="Bytes per parameter for Adam moments")
    parser.add_argument("--dp-degree", type=int, default=8)
    return parser.parse_args()


def gb(value: float) -> float:
    """Convert a byte count to decimal gigabytes (1 GB = 1e9 bytes)."""
    bytes_per_gb = 1e9
    return value / bytes_per_gb


def main() -> None:
    """Print per-GPU memory estimates for DDP and ZeRO stages 1-3.

    Reads the model size, byte widths and data-parallel degree from the
    command line, then reports how much parameter, gradient and Adam state a
    single rank holds as progressively more of that state is sharded. Only
    those three kinds of state are counted, and sizes are in decimal GB.
    """
    args = parse_args()
    n_params = args.params_b * 1e9
    # A degree below 1 is treated as 1 (no sharding), which also keeps the
    # divisions below well-defined.
    dp = max(1, args.dp_degree)

    # Full, unsharded size of each kind of state. Parameters and gradients
    # share one byte width; the Adam moments have their own.
    param_gb = gb(n_params * args.precision_bytes)
    grad_gb = gb(n_params * args.precision_bytes)
    adam_gb = gb(n_params * args.adam_bytes)

    # Slice of each kind of state a rank keeps once it is split evenly
    # across the data-parallel group.
    param_shard = param_gb / dp
    grad_shard = grad_gb / dp
    adam_shard = adam_gb / dp

    # DDP baseline: every rank holds a complete copy of everything.
    replicated = param_gb + grad_gb + adam_gb
    # ZeRO-1: only optimizer state is sharded; parameters and gradients are
    # still fully replicated.
    zero1 = param_gb + grad_gb + adam_shard
    # ZeRO-2: gradients are sharded as well, typically by reduce-scattering
    # them so each rank ends up with just its own slice.
    zero2 = param_gb + grad_shard + adam_shard
    # ZeRO-3/FSDP: parameters are sharded too, so full weights have to be
    # all-gathered before any computation that needs them. Memory is traded
    # for communication.
    zero3 = param_shard + grad_shard + adam_shard

    report = [
        f"params={args.params_b:.2f}B dp_degree={dp}",
        f"param_memory={param_gb:.2f} GB",
        f"grad_memory={grad_gb:.2f} GB",
        f"adam_memory={adam_gb:.2f} GB",
        "",
        f"DDP replicated state: {replicated:.2f} GB per GPU",
        f"ZeRO-1 estimate:      {zero1:.2f} GB per GPU  # shard optimizer states",
        f"ZeRO-2 estimate:      {zero2:.2f} GB per GPU  # shard optimizer states + gradients",
        f"ZeRO-3 estimate:      {zero3:.2f} GB per GPU  # shard params + gradients + optimizer states",
        "",
        "Communication intuition:",
        "- ZeRO-3/FSDP needs parameter all-gather before computation that requires full parameters.",
        "- Gradients are commonly reduce-scattered so each rank keeps only its shard.",
        "- Real peak memory includes gather buckets, prefetching, activations and allocator effects.",
    ]
    for line in report:
        print(line)


# Run the estimator only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
