"""InfraLens PyTorch example: scaled dot-product and multi-head attention.

Educational clean-room implementation adapted from the algorithms covered by
ARIS-in-AI-Offer (MIT License). This file is intentionally small: it shows
Q/K/V shapes, causal masking, padding masks, and the all-masked-row NaN trap.
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def safe_masked_softmax(scores: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
    """Softmax over the last axis that yields all-zero rows when no key is kept.

    Args:
        scores: attention logits, nominally [B, H, Q, K].
        mask: optional tensor broadcastable against ``scores``; it is cast to
            bool and True entries are kept. ``None`` means a plain softmax.

    Returns:
        Probabilities with the broadcast shape of ``scores`` and ``mask``.
        Masked keys receive (numerically) zero weight, and a row whose keys
        are all masked is returned as exact zeros instead of a uniform
        distribution or NaN.
    """
    if mask is not None:
        keep = mask.bool()
        # A row is "dead" when none of its keys survive the mask.
        dead_rows = (~keep).all(dim=-1, keepdim=True)
        # Use the most negative finite value instead of -inf so softmax never
        # sees an all -inf row (which would give NaN); dead rows are reset to a
        # neutral 0 before softmax and zeroed again afterwards.
        very_negative = torch.finfo(scores.dtype).min
        logits = scores.masked_fill(~keep, very_negative).masked_fill(dead_rows, 0.0)
        return logits.softmax(dim=-1).masked_fill(dead_rows, 0.0)
    return torch.softmax(scores, dim=-1)


def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute softmax(Q K^T / sqrt(d_k)) V, restricted to keys allowed by ``mask``.

    Args:
        q: queries, [B, H, Sq, Dh].
        k: keys, [B, H, Sk, Dh] (Sk may differ from Sq, e.g. cross-attention).
        v: values, [B, H, Sk, Dv].
        mask: optional boolean *keep* mask broadcastable to [B, H, Sq, Sk],
            e.g. [B, 1, Sq, Sk]; True means "attend". It is not added to the
            scores like an additive -inf mask. Masked keys get zero weight via
            ``safe_masked_softmax``. A query row with no kept key gets all-zero
            probabilities and therefore a zero output vector instead of NaN.

    Returns:
        Attention output of shape [B, H, Sq, Dv].
    """
    d_k = q.shape[-1]
    # Scale q before the matmul instead of dividing the scores afterwards.
    # The two are mathematically equivalent, but this keeps the unscaled
    # Q K^T from overflowing to inf in float16 / low-precision inputs.
    # Such an inf would otherwise become NaN after softmax.
    scores = (q / math.sqrt(d_k)) @ k.transpose(-2, -1)
    probs = safe_masked_softmax(scores, mask)
    return probs @ v


class TinyMultiHeadAttention(nn.Module):
    """Minimal multi-head self-attention block that keeps shape changes explicit.

    A single fused linear layer produces Q, K and V. The heads are split out,
    attended independently with ``scaled_dot_product_attention``, merged back,
    and passed through the ``out`` projection. Neither projection has a bias.
    """

    def __init__(self, hidden_dim: int, num_heads: int) -> None:
        """Create the fused QKV projection and the output projection.

        Args:
            hidden_dim: model width D; must be divisible by ``num_heads``.
            num_heads: number of attention heads H; must be positive.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not divide
                ``hidden_dim``.
        """
        super().__init__()
        # Use an explicit check rather than ``assert``. Asserts are stripped
        # under ``python -O``, and the failure would then surface later as an
        # opaque ``view`` error (or a ZeroDivisionError for num_heads == 0).
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [B, S, D] into per-head [B, H, S, D // H]."""
        bsz, seq, dim = x.shape
        return x.view(bsz, seq, self.num_heads, dim // self.num_heads).transpose(1, 2)

    def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of ``split_heads``: reshape [B, H, S, Dh] into [B, S, H * Dh]."""
        bsz, heads, seq, head_dim = x.shape
        # transpose() yields a non-contiguous tensor; view() needs contiguous memory.
        return x.transpose(1, 2).contiguous().view(bsz, seq, heads * head_dim)

    def forward(
        self,
        x: torch.Tensor,
        causal: bool = True,
        *,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Self-attend over ``x``.

        Args:
            x: input of shape [B, S, D].
            causal: if True, position i may only attend to positions <= i.
            attention_mask: optional [B, S] padding mask. True (or nonzero)
                marks a real token and False marks padding. Padded positions
                are excluded as keys for every query. Outputs at padded query
                positions are still computed and should be ignored by the
                caller. A query left with no attendable key gets a zero context
                vector rather than NaN. One example is the first position of a
                left-padded sequence under a causal mask.

        Returns:
            Tensor of shape [B, S, D].
        """
        bsz, seq, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(self.split_heads, (q, k, v))  # each [B, H, S, Dh]

        mask: torch.Tensor | None = None
        if causal:
            # [1, 1, S, S] lower-triangular keep mask, broadcast over batch and heads.
            mask = torch.ones(seq, seq, dtype=torch.bool, device=x.device).tril()
            mask = mask.view(1, 1, seq, seq)
        if attention_mask is not None:
            # [B, S] -> [B, 1, 1, S]: the same key mask for every head and query row.
            key_keep = attention_mask.to(device=x.device, dtype=torch.bool).reshape(bsz, 1, 1, seq)
            mask = key_keep if mask is None else mask & key_keep

        context = scaled_dot_product_attention(q, k, v, mask)  # [B, H, S, Dh]
        return self.out(self.merge_heads(context))


def smoke_test() -> None:
    """Run quick sanity checks on the attention helpers and print a summary.

    Checks that:
    - ``TinyMultiHeadAttention`` preserves the input shape.
    - Its default causal mask stops earlier positions from seeing later ones.
    - ``safe_masked_softmax`` returns zeros (not NaN) for a fully masked row.
    - A partially masked row is still normalized, with exactly zero weight on
      its masked key.

    Raises:
        AssertionError: if any check fails.
    """
    torch.manual_seed(0)
    x = torch.randn(2, 4, 8)  # [batch=2, seq=4, hidden=8]
    mha = TinyMultiHeadAttention(hidden_dim=8, num_heads=2)
    with torch.no_grad():
        y = mha(x)
        # Causality: changing only the last position must leave all earlier
        # outputs untouched, while the last output itself does change.
        x_perturbed = x.clone()
        x_perturbed[:, -1] += 1.0
        y_perturbed = mha(x_perturbed)
    assert y.shape == x.shape
    assert torch.allclose(y[:, :-1], y_perturbed[:, :-1])
    assert not torch.allclose(y[:, -1], y_perturbed[:, -1])

    scores = torch.randn(1, 1, 2, 3)
    mask = torch.tensor([[[[False, False, False], [True, False, True]]]])
    probs = safe_masked_softmax(scores, mask)
    assert not torch.isnan(probs).any()
    assert torch.allclose(probs[0, 0, 0], torch.zeros(3))
    assert torch.allclose(probs[0, 0, 1].sum(), torch.tensor(1.0))
    assert probs[0, 0, 1, 1].item() == 0.0  # masked key gets exactly zero weight
    print("attention.py ok", {"output_shape": tuple(y.shape), "masked_row": probs[0, 0, 0].tolist()})


if __name__ == "__main__":
    # Run the sanity checks only when executed as a script, not on import.
    smoke_test()
