"""InfraLens PyTorch example: KV cache for autoregressive decode.

Educational clean-room implementation adapted from the algorithms covered by
ARIS-in-AI-Offer (MIT License). Demonstrates full causal attention versus
incremental decode with cached K/V tensors.
"""

from __future__ import annotations

import math
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class KVCache:
    """Growing key/value store for autoregressive decoding.

    Both tensors are laid out ``[B, H_kv, T, Dh]`` and grow along the time
    axis (dim 2) on every :meth:`append`. ``None`` means nothing has been
    cached yet.
    """

    k: torch.Tensor | None = None  # cached keys,   [B, H_kv, T, Dh]
    v: torch.Tensor | None = None  # cached values, [B, H_kv, T, Dh]

    @staticmethod
    def _extend(cached: torch.Tensor | None, fresh: torch.Tensor) -> torch.Tensor:
        # The first step stores the tensor as-is. Later steps concatenate on
        # the time axis; torch.cat allocates a new tensor, so each append
        # copies the whole cache.
        if cached is None:
            return fresh
        return torch.cat((cached, fresh), dim=2)

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Add new timesteps to the cache and return the full cached ``(k, v)``."""
        self.k = self._extend(self.k, k_new)
        self.v = self._extend(self.v, v_new)
        return self.k, self.v


def repeat_kv(x: torch.Tensor, num_query_heads: int) -> torch.Tensor:
    """Broadcast shared MQA/GQA key/value heads out to one copy per query head.

    ``x`` is ``[B, H_kv, T, Dh]``. The result is ``[B, num_query_heads, T, Dh]``,
    where each KV head appears ``num_query_heads // H_kv`` times consecutively.
    Query heads ``g*r .. g*r + r - 1`` therefore all read KV head ``g``.
    """
    batch, groups, length, dim = x.shape
    per_group = num_query_heads // groups
    # Add a singleton "copies" axis next to the head axis, broadcast it, then
    # fold (groups, per_group) back into a single head axis.
    widened = x.unsqueeze(2).expand(batch, groups, per_group, length, dim)
    return widened.reshape(batch, num_query_heads, length, dim)


class TinyCachedSelfAttention(nn.Module):
    """Single causal self-attention layer with optional grouped-query KV heads.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads:
    - equal counts give plain multi-head attention;
    - ``num_kv_heads == 1`` gives multi-query attention;
    - anything in between gives grouped-query attention.

    Passing a ``KVCache`` to :meth:`forward` appends that call's K/V to it, so
    later calls attend over every previously seen position.
    """

    def __init__(self, hidden_dim: int = 16, num_heads: int = 4, num_kv_heads: int = 2) -> None:
        super().__init__()
        assert hidden_dim % num_heads == 0
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_dim // num_heads
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
        # K/V projections are narrower than Q: only num_kv_heads heads are stored.
        self.k_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim, bias=False)
        self.out = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def _shape(self, x: torch.Tensor, heads: int) -> torch.Tensor:
        """Split ``[B, S, heads * Dh]`` into ``[B, heads, S, Dh]``."""
        bsz, seq, _ = x.shape
        return x.view(bsz, seq, heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, x: torch.Tensor, cache: KVCache | None = None) -> tuple[torch.Tensor, KVCache | None]:
        """Run causal self-attention over ``x``.

        Args:
            x: ``[B, S, D]`` hidden states for S new positions. S is usually 1
                during decode and larger during prefill.
            cache: optional KV cache. When given, this call's K/V are appended
                to it. The queries then attend over all cached positions plus
                the new ones, never to positions after themselves.

        Returns:
            A tuple of the ``[B, S, D]`` output and the same ``cache`` object
            that was passed in (``None`` if none was).
        """
        q = self._shape(self.q_proj(x), self.num_heads)
        k_new = self._shape(self.k_proj(x), self.num_kv_heads)
        v_new = self._shape(self.v_proj(x), self.num_kv_heads)

        if cache is not None:
            k, v = cache.append(k_new, v_new)
        else:
            k, v = k_new, v_new

        k = repeat_kv(k, self.num_heads)
        v = repeat_kv(v, self.num_heads)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)

        # Causal mask with a cache offset. The new positions were appended
        # after `past_len` cached ones, so query i sits at absolute position
        # past_len + i and may see keys 0..past_len + i. tril(diagonal=past_len)
        # keeps exactly those entries. Without a cache, past_len == 0 and this
        # is the usual square causal mask. A single query sees every key, so
        # the mask is a no-op and is skipped.
        q_len, kv_len = scores.shape[-2], scores.shape[-1]
        if q_len > 1:
            past_len = kv_len - q_len
            causal = torch.ones(q_len, kv_len, dtype=torch.bool, device=x.device).tril(diagonal=past_len)
            scores = scores.masked_fill(~causal.view(1, 1, q_len, kv_len), torch.finfo(scores.dtype).min)

        out = torch.softmax(scores, dim=-1) @ v
        bsz, heads, seq, head_dim = out.shape
        out = out.transpose(1, 2).contiguous().view(bsz, seq, heads * head_dim)
        return self.out(out), cache


def kv_cache_bytes(batch: int, seq_len: int, layers: int, kv_heads: int, head_dim: int, bytes_per_elem: int) -> int:
    """Estimate the memory held by a KV cache, in bytes.

    The model keeps one K tensor and one V tensor (hence the leading factor 2)
    per layer. Each has ``batch * seq_len * kv_heads * head_dim`` elements of
    ``bytes_per_elem`` bytes each (e.g. 2 for fp16/bf16).
    """
    factors = (2, batch, seq_len, layers, kv_heads, head_dim, bytes_per_elem)
    return math.prod(factors)


def smoke_test() -> None:
    """Check that cached incremental decode reproduces full causal attention.

    The same 5-token sequence goes through ``TinyCachedSelfAttention`` twice:
    once as a single full pass, and once token by token with a ``KVCache``.
    The test asserts the outputs agree and that the cache holds every step.
    It then prints the final cache shape and a KV-cache size estimate for a
    larger configuration.
    """
    torch.manual_seed(0)
    model = TinyCachedSelfAttention()
    x = torch.randn(1, 5, 16)
    seq_len = x.shape[1]

    # Pure inference: no autograd graph is needed. Without no_grad, the cached
    # K/V would drag a growing graph through every decode step.
    with torch.no_grad():
        full, _ = model(x, cache=None)

        cache = KVCache()
        pieces = []
        for t in range(seq_len):
            y, _ = model(x[:, t : t + 1], cache=cache)
            pieces.append(y)
    incremental = torch.cat(pieces, dim=1)

    assert torch.allclose(full, incremental, atol=1e-5)
    # Every decoded position must have landed in the cache (time axis is dim 2).
    assert cache.k is not None and cache.k.shape[2] == seq_len
    bytes_est = kv_cache_bytes(batch=4, seq_len=2048, layers=32, kv_heads=8, head_dim=128, bytes_per_elem=2)
    print("kv_cache.py ok", {"match": True, "cache_shape": tuple(cache.k.shape), "bytes_estimate_mb": round(bytes_est / 2**20, 2)})


if __name__ == "__main__":
    # Run the self-check only when executed as a script, not on import.
    smoke_test()
