"""InfraLens PyTorch example: Transformer whiteboard primitives.

Clean-room educational implementation inspired by common LLM live-coding
interview topics. It keeps the code compact enough to rewrite on a whiteboard:
RMSNorm, SwiGLU, causal self-attention, and a tiny decoder block.
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last (hidden) dimension.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. ``weight`` is a
    learnable per-feature scale initialized to ones. Unlike LayerNorm, there is
    no mean subtraction and no bias.

    The reduction runs in at least float32, so float16/bfloat16 inputs neither
    overflow nor lose precision while squaring and averaging. The normalized
    values are cast back to the input dtype before ``weight`` is applied.
    """

    def __init__(self, hidden_dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_dim))
        # Added inside the square root so an all-zero row does not divide by zero.
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes a floating-point ``x``.
        # Squaring in fp16 overflows above |x| ~ 256 (fp16 max is 65504).
        # rsqrt(inf) is 0, which would silently zero the whole row.
        # promote_types keeps float64 as float64 and lifts half types to float32.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_compute = x.to(compute_dtype)
        rms = x_compute.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        # Cast back first, so dtype promotion against ``weight`` matches the
        # plain ``x * rms * weight`` formulation.
        normed = (x_compute * rms).to(x.dtype)
        return normed * self.weight


class SwiGLU(nn.Module):
    """SwiGLU feed-forward network.

    Evaluates ``down(silu(gate(x)) * up(x))``. A SiLU-activated gate branch
    scales a parallel linear "up" branch elementwise. The product is then
    projected from ``ffn_dim`` back to ``hidden_dim``. All three projections
    are bias-free.
    """

    def __init__(self, hidden_dim: int, ffn_dim: int) -> None:
        super().__init__()
        # Registration order (gate, up, down) fixes parameter init order and state_dict keys.
        self.gate = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.up = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.down = nn.Linear(ffn_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated_gate = F.silu(self.gate(x))
        up_branch = self.up(x)
        mixed = activated_gate * up_branch
        return self.down(mixed)


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where position i attends only to positions <= i.

    One fused bias-free projection produces queries, keys and values. These are
    split into ``num_heads`` heads of size ``head_dim`` and combined with
    dot-product attention scaled by ``1 / sqrt(head_dim)``. Future positions are
    filled with the dtype's most negative finite value before the softmax. The
    merged heads go through a bias-free output projection.
    """

    def __init__(self, hidden_dim: int, num_heads: int) -> None:
        super().__init__()
        # Heads must evenly partition the hidden size.
        assert hidden_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Fused projection; forward() chunks its output into q, k and v.
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(batch, seq, dim)`` to ``(batch, num_heads, seq, dim // num_heads)``."""
        batch, length, width = x.shape
        per_head = width // self.num_heads
        return x.view(batch, length, self.num_heads, per_head).transpose(1, 2)

    def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of ``split_heads``: ``(batch, heads, seq, head_dim)`` -> ``(batch, seq, heads * head_dim)``."""
        batch, heads, length, per_head = x.shape
        # The transposed tensor is generally non-contiguous.
        # reshape copies when needed, where view would refuse.
        return x.transpose(1, 2).reshape(batch, length, heads * per_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape ``(batch, seq, hidden_dim)``; output has the same shape."""
        batch, length, _ = x.shape
        query, key, value = (self.split_heads(part) for part in self.qkv(x).chunk(3, dim=-1))

        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # True strictly above the diagonal, i.e. for keys in the query's future.
        future = torch.ones(length, length, dtype=torch.bool, device=x.device).triu(diagonal=1)
        logits = logits.masked_fill(future, torch.finfo(logits.dtype).min)

        context = torch.matmul(logits.softmax(dim=-1), value)
        assert context.shape == (batch, self.num_heads, length, self.head_dim)
        return self.out(self.merge_heads(context))


class TinyDecoderBlock(nn.Module):
    """Pre-norm transformer decoder layer.

    Two residual sub-layers run in sequence, each normalizing its input with
    RMSNorm first::

        h = x + attn(attn_norm(x))
        out = h + ffn(ffn_norm(h))
    """

    def __init__(self, hidden_dim: int = 16, num_heads: int = 4, ffn_dim: int = 32) -> None:
        super().__init__()
        # Submodules are created in forward-pass order.
        # This order also fixes parameter initialization order and state_dict keys.
        self.attn_norm = RMSNorm(hidden_dim)
        self.attn = CausalSelfAttention(hidden_dim, num_heads)
        self.ffn_norm = RMSNorm(hidden_dim)
        self.ffn = SwiGLU(hidden_dim, ffn_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.attn_norm(x))
        return attended + self.ffn(self.ffn_norm(attended))


def assert_causal_no_leak(
    attn: CausalSelfAttention,
    x: torch.Tensor,
    prefix_len: int = 3,
    atol: float = 1e-5,
) -> None:
    """Check that perturbing future tokens leaves earlier attention outputs unchanged.

    Positions ``prefix_len:`` along dim 1 of ``x`` are replaced with large
    random values, and ``attn`` is run on both the original and the perturbed
    input. Outputs at positions ``:prefix_len`` must match within ``atol``.

    Args:
        attn: Module or callable mapping ``x`` to an output whose dim 1 is the
            sequence axis, such as ``CausalSelfAttention``.
        x: Input with the sequence on dim 1, e.g. ``(batch, seq, hidden)``.
        prefix_len: Number of leading positions that must be unaffected. Must
            satisfy ``0 < prefix_len < seq`` so that both the checked prefix and
            the perturbed suffix are non-empty.
        atol: Absolute tolerance for the comparison.

    Raises:
        ValueError: If ``prefix_len`` is outside ``(0, seq)``. Without this
            guard the check would compare empty tensors, or perturb nothing,
            and pass vacuously.
        AssertionError: If any prefix output changes. It is raised explicitly,
            so the check still runs under ``python -O``.
    """
    seq = x.shape[1]
    if not 0 < prefix_len < seq:
        raise ValueError(f"prefix_len must be in (0, {seq}), got {prefix_len}")

    changed = x.clone()
    changed[:, prefix_len:] = torch.randn_like(changed[:, prefix_len:]) * 10.0

    # Pure comparison; no autograd graph is needed.
    with torch.no_grad():
        base = attn(x)
        future_changed = attn(changed)

    before = base[:, :prefix_len]
    after = future_changed[:, :prefix_len]
    if not torch.allclose(before, after, atol=atol):
        max_diff = (before - after).abs().max().item()
        raise AssertionError(
            f"causal leak: outputs before position {prefix_len} changed "
            f"(max abs diff {max_diff:.3e})"
        )


def smoke_test() -> None:
    """Exercise every primitive on a seeded random batch and print a summary.

    Checks that each module preserves the input shape, that the decoder output
    is finite, and that attention does not leak future tokens.
    """
    torch.manual_seed(7)
    inputs = torch.randn(2, 5, 16)
    block = TinyDecoderBlock()
    decoded = block(inputs)

    # Construction order matters: each module draws its initial weights from
    # the seeded global RNG.
    norm_out = RMSNorm(16)(inputs)
    swiglu_out = SwiGLU(16, 32)(inputs)
    attention = CausalSelfAttention(16, 4)
    attention_out = attention(inputs)
    assert_causal_no_leak(attention, inputs)

    for produced in (decoded, norm_out, swiglu_out, attention_out):
        assert produced.shape == inputs.shape
    assert torch.isfinite(decoded).all()

    summary = {
        "input": tuple(inputs.shape),
        "decoder_output": tuple(decoded.shape),
        "heads": block.attn.num_heads,
        "ffn_dim": block.ffn.up.out_features,
        "causal_no_leak": True,
    }
    print("transformer_whiteboard_primitives.py ok", summary)


if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script, not when imported.
    smoke_test()
