"""InfraLens PyTorch example: MHA/MQA/GQA/MLA KV cache shape and memory.

Educational clean-room implementation adapted from algorithms covered by
ARIS-in-AI-Offer (MIT License). CPU-safe tensor accounting plus a toy page
table for fragmented request lengths.
"""

from __future__ import annotations

import math

import torch


def kv_bytes(
    batch: int,
    seq: int,
    layers: int,
    kv_heads: int,
    head_dim: int,
    bytes_per_elem: int = 2,
) -> int:
    """Return the KV cache size in bytes for the given dimensions.

    Args:
        batch: Number of sequences held in the cache.
        seq: Number of cached positions per sequence.
        layers: Number of layers that each keep their own cache.
        kv_heads: Number of key/value heads per layer.
        head_dim: Width of each head.
        bytes_per_elem: Storage size of one element (default 2).

    Returns:
        Total byte count: two tensors (keys and values), each holding
        ``batch * seq * layers * kv_heads * head_dim`` elements.
    """
    cached_positions = batch * seq
    elems_per_position = layers * kv_heads * head_dim
    elems_per_tensor = cached_positions * elems_per_position
    # Factor of two: one tensor for keys, one for values.
    return 2 * elems_per_tensor * bytes_per_elem


def compare_mha_mqa_gqa(
    batch: int = 4,
    seq: int = 2048,
    layers: int = 32,
    q_heads: int = 32,
    head_dim: int = 128,
) -> dict[str, float]:
    """Report KV cache size in GiB for three KV-head layouts.

    The layouts differ only in how many KV heads are cached: ``"MHA"``
    uses ``q_heads``, ``"GQA-8KV"`` uses a fixed 8, and ``"MQA"`` uses 1.
    Sizes come from ``kv_bytes`` with its default element width.

    Returns:
        Mapping from layout name to size in GiB (bytes / 2**30), rounded
        to three decimals, in the order MHA, GQA-8KV, MQA.
    """
    bytes_per_gib = 2**30
    kv_heads_by_layout = (("MHA", q_heads), ("GQA-8KV", 8), ("MQA", 1))

    report: dict[str, float] = {}
    for layout, kv_heads in kv_heads_by_layout:
        total_bytes = kv_bytes(batch, seq, layers, kv_heads, head_dim)
        report[layout] = round(total_bytes / bytes_per_gib, 3)
    return report


def toy_mla_cache(k: torch.Tensor, latent_dim: int) -> torch.Tensor:
    """Compress K [B,H,S,Dh] to latent [B,H,S,R] with a random projection.

    A fresh Gaussian down-projection of shape ``[Dh, R]`` is drawn on every
    call (seed the global RNG for repeatable output) and scaled by
    ``1 / sqrt(Dh)``.

    Args:
        k: Floating-point tensor whose last dimension is the head width
            ``Dh``; leading dimensions are carried through unchanged.
        latent_dim: Size ``R`` of the compressed last dimension.

    Returns:
        ``k @ down``: same leading dimensions as ``k``, last dimension
        ``latent_dim``, with ``k``'s dtype and device.
    """
    head_dim = k.shape[-1]
    # Create the projection in k's dtype and on k's device. With the default
    # (float32, CPU) the matmul raises for e.g. float64 or non-CPU inputs.
    down = torch.randn(head_dim, latent_dim, dtype=k.dtype, device=k.device)
    down = down / math.sqrt(head_dim)
    return k @ down


def paged_blocks(lengths: list[int], block_size: int) -> dict[int, list[int]]:
    """Build a toy page table assigning consecutive block ids to requests.

    Block ids come from a single increasing counter starting at 0, so each
    request's blocks follow directly after the previous request's. A request
    of ``length`` tokens receives ``ceil(length / block_size)`` blocks; a
    zero-length request receives an empty list.

    Args:
        lengths: Token count of each request, indexed by request id.
        block_size: Number of tokens one block holds; must be positive.

    Returns:
        Mapping from request index to the list of block ids it owns.

    Raises:
        ValueError: If ``block_size`` is not positive or a length is negative.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    table: dict[int, list[int]] = {}
    next_block = 0
    for req, length in enumerate(lengths):
        if length < 0:
            # A negative length would yield a negative block count and rewind
            # the counter, making later requests reuse already-assigned ids.
            raise ValueError(f"length for request {req} must be non-negative, got {length}")
        # Exact integer ceil-division; going through float division rounds
        # incorrectly once the operands exceed float precision.
        blocks = -(-length // block_size)
        table[req] = list(range(next_block, next_block + blocks))
        next_block += blocks
    return table


def smoke_test() -> None:
    """Exercise every helper once on small inputs and print a summary.

    Seeds the global torch RNG, runs the memory comparison, the toy latent
    projection and the toy page table, asserts the expected ordering, shape
    and block assignment, then prints the collected results.
    """
    torch.manual_seed(0)  # fixed seed so the random tensors are repeatable

    memory_gb = compare_mha_mqa_gqa()
    keys = torch.randn(2, 8, 16, 64)
    compressed = toy_mla_cache(keys, latent_dim=12)
    block_table = paged_blocks([3, 9, 2], block_size=4)

    # Fewer KV heads must mean a strictly smaller cache.
    assert memory_gb["MHA"] > memory_gb["GQA-8KV"] > memory_gb["MQA"]
    # Only the last dimension is compressed (64 -> 12).
    assert compressed.shape == (2, 8, 16, 12)
    # Request 1 (9 tokens, 4 per block) takes three blocks after request 0's one.
    assert block_table[1] == [1, 2, 3]

    summary = {
        "memory_gb": memory_gb,
        "latent_shape": tuple(compressed.shape),
        "block_table": block_table,
    }
    print("kv_cache_variants.py ok", summary)


# Script entry point: run the smoke test only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    smoke_test()
