#!/usr/bin/env python3
"""Rough Transformer parameter and memory accounting.

This script is intentionally interview-grade rather than model-exact. It uses
common approximations:
- attention params per layer ~= 4 * D^2 for Q, K, V, O projections
- FFN params per layer ~= m * D^2
- embedding params ~= vocab_size * D
- Adam states ~= 8 bytes/param by default for FP32 first and second moments
  (an FP32 master copy of the weights, if your recipe keeps one, is not included)
- KV Cache ~= batch * seq * layers * heads * head_dim * 2(K,V) * bytes/element

It excludes activations, allocator fragmentation, temporary workspaces and
communication buffers unless you add your own estimates.
"""

from __future__ import annotations

import argparse


def gb(bytes_value: float) -> float:
    """Convert a byte count to decimal gigabytes (1 GB = 1e9 bytes, not GiB)."""
    bytes_per_gigabyte = 1e9
    return bytes_value / bytes_per_gigabyte


def fmt_params(value: float) -> str:
    """Render a parameter count with a B (billions) or M (millions) suffix.

    Counts of at least one billion are shown with three decimals. Anything
    smaller is shown in millions with two decimals, so tiny counts read as
    e.g. "0.01M".
    """
    in_billions = value >= 1e9
    divisor, decimals, suffix = (1e9, 3, "B") if in_billions else (1e6, 2, "M")
    return f"{value / divisor:.{decimals}f}{suffix}"


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` callers behave as
            before. Passing an explicit list lets tests or notebooks drive the
            parser directly.

    Returns:
        The parsed options namespace.

    Exits with status 2 (via ``parser.error``) when a value would make the
    estimate meaningless. That covers:
    - a non-positive hidden size or head count;
    - a hidden size that does not split evenly into heads;
    - a negative or NaN count or byte size.
    """
    parser = argparse.ArgumentParser(description="Transformer memory accounting starter")
    parser.add_argument("--hidden-dim", type=int, default=4096, help="model width D")
    parser.add_argument(
        "--num-layers", type=int, default=32, help="number of decoder blocks L"
    )
    parser.add_argument(
        "--num-heads", type=int, default=32, help="attention heads H (must divide D)"
    )
    parser.add_argument(
        "--vocab-size", type=int, default=32000, help="rows in the token embedding table"
    )
    parser.add_argument(
        "--ffn-multiplier",
        type=float,
        default=8.0,
        help="FFN params per layer in units of D^2",
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, help="sequences held in the KV Cache"
    )
    parser.add_argument(
        "--seq-len", type=int, default=2048, help="tokens per cached sequence"
    )
    parser.add_argument(
        "--precision-bytes",
        type=float,
        default=2.0,
        help="bytes per weight, gradient and KV Cache element",
    )
    parser.add_argument(
        "--adam-bytes", type=float, default=8.0, help="optimizer-state bytes per parameter"
    )
    args = parser.parse_args(argv)

    # main() computes head_dim = D / H. A zero head count must be rejected here
    # rather than surfacing later as a ZeroDivisionError traceback.
    for dest in ("hidden_dim", "num_heads"):
        if getattr(args, dest) <= 0:
            parser.error(f"--{dest.replace('_', '-')} must be positive")

    # `not value >= 0` also rejects NaN, which argparse's float() accepts.
    for dest in (
        "num_layers",
        "vocab_size",
        "batch_size",
        "seq_len",
        "ffn_multiplier",
        "precision_bytes",
        "adam_bytes",
    ):
        if not getattr(args, dest) >= 0:
            parser.error(f"--{dest.replace('_', '-')} must be non-negative")

    # Heads partition D, so an uneven split would give a fractional head size.
    if args.hidden_dim % args.num_heads:
        parser.error("--hidden-dim must be divisible by --num-heads")
    return args


def _print_row(label: str, value: object) -> None:
    """Print one report line with the value starting in a fixed column.

    Labels are left-justified to 26 characters, so every value lines up
    without hand-counted padding. The longest label,
    "replicated training state:", is exactly that wide.
    """
    print(f"{label:<26}{value}")


def main() -> None:
    """Parse CLI options, estimate parameters and memory, and print a report.

    The report has three sections: parameter counts, replicated training-state
    memory (weights + gradients + Adam states) and inference KV Cache memory.
    Nothing is returned.
    """
    args = parse_args()
    d = args.hidden_dim
    layers = args.num_layers
    heads = args.num_heads
    head_dim = d / heads

    # Read these formulas as a shape ledger, not as an exact model parser.
    # In a standard decoder block, Q, K, V and output projections are each
    # roughly D x D. Splitting into heads changes how the D dimension is
    # partitioned during attention; it does not multiply the projection matrix
    # count by the number of heads again.
    attention_per_layer = 4 * d * d

    # FFN implementations differ: classic Transformer uses two large matrices,
    # while gated variants such as SwiGLU use three. The multiplier compresses
    # those details into one interview-friendly approximation.
    ffn_per_layer = args.ffn_multiplier * d * d

    # Token embeddings are a vocab-size by hidden-dim lookup table. Large vocab
    # or untied output heads can make this term visible, but decoder layers
    # usually dominate for multi-billion-parameter models.
    embedding = args.vocab_size * d
    total_params = layers * (attention_per_layer + ffn_per_layer) + embedding

    # Training state is the bridge from "parameter count" to "why ZeRO/FSDP
    # exists". A DDP-style setup replicates these tensors on every data-parallel
    # rank unless a sharding strategy changes the layout.
    param_memory = total_params * args.precision_bytes
    grad_memory = total_params * args.precision_bytes

    # Adam keeps first and second moment tensors. With FP32 moments this is often
    # approximated as 8 bytes per parameter, which is why Adam is much more
    # memory-hungry than SGD-style updates. Mixed-precision recipes commonly also
    # keep an FP32 master copy of the weights (another 4 bytes/param). That copy
    # is not counted here unless --adam-bytes is raised (e.g. to 12).
    adam_memory = total_params * args.adam_bytes
    training_state_memory = param_memory + grad_memory + adam_memory

    # KV Cache is an inference-serving object, not a training optimizer state.
    # It grows with batch, sequence length, layer count, and the K/V head shape.
    # GQA/MQA changes the number of KV heads; this simple starter assumes KV
    # heads equal attention heads.
    kv_cache_bytes = (
        args.batch_size
        * args.seq_len
        * layers
        * heads
        * head_dim
        * 2
        * args.precision_bytes
    )

    print("Transformer parameter estimate")
    print("--------------------------------")
    _print_row("hidden_dim D:", d)
    _print_row("num_layers L:", layers)
    _print_row("num_heads H:", heads)
    _print_row("head_dim D/H:", f"{head_dim:.1f}")
    _print_row("attention per layer:", fmt_params(attention_per_layer))
    _print_row("FFN per layer:", fmt_params(ffn_per_layer))
    _print_row("embedding:", fmt_params(embedding))
    _print_row("total params:", fmt_params(total_params))
    print()
    print("Training state memory estimate")
    print("--------------------------------")
    _print_row("parameter memory:", f"{gb(param_memory):.2f} GB")
    _print_row("gradient memory:", f"{gb(grad_memory):.2f} GB")
    _print_row("Adam state memory:", f"{gb(adam_memory):.2f} GB")
    _print_row("replicated training state:", f"{gb(training_state_memory):.2f} GB")
    print()
    print("Inference KV Cache estimate")
    print("----------------------------")
    _print_row("batch_size:", args.batch_size)
    _print_row("seq_len:", args.seq_len)
    _print_row("KV Cache memory:", f"{gb(kv_cache_bytes):.2f} GB")
    print()
    print("Notes")
    print("-----")
    print("- Training also needs activations, temporary workspaces and communication buffers.")
    print("- Inference removes gradients/optimizer states but adds KV Cache pressure.")
    print("- Real models differ by FFN variant, tied embeddings, bias/norm choices and GQA/MQA.")


if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing. main()
    # returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
