"""InfraLens PyTorch example: RoPE, position interpolation, and NTK-style scaling.

Educational clean-room implementation adapted from algorithms covered by
ARIS-in-AI-Offer (MIT License). This is a CPU-safe tensor demo, not a model.
"""

from __future__ import annotations

import torch


def rope_frequencies(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Compute the RoPE inverse-frequency vector.

    Channel pair ``(2k, 2k + 1)`` is assigned the inverse frequency
    ``base ** (-2k / head_dim)``.

    Args:
        head_dim: Per-head feature size; must be even so channels pair up.
        base: Geometric base controlling the spread of wavelengths.

    Returns:
        A float32 tensor with ``head_dim // 2`` entries.
    """
    assert head_dim % 2 == 0
    even_channels = torch.arange(0, head_dim, 2).float()
    exponents = even_channels / head_dim
    denominators = base ** exponents
    return 1.0 / denominators


def apply_rope(x: torch.Tensor, positions: torch.Tensor, inv_freq: torch.Tensor) -> torch.Tensor:
    """Apply RoPE to x shaped [..., S, Dh] (typically [B, H, S, Dh]).

    Channel pairs ``(2k, 2k + 1)`` of each token are rotated by
    ``positions[s] * inv_freq[k]`` radians (interleaved-pair convention).

    Args:
        x: Input whose last two dims are ``[S, Dh]``; any number of leading
            dims is allowed. ``Dh`` must equal ``2 * len(inv_freq)``.
        positions: Length-``S`` positions; may be fractional (e.g. from
            position interpolation).
        inv_freq: Length-``Dh // 2`` inverse frequencies.

    Returns:
        Rotated tensor with the same shape as ``x``. For floating-point ``x``
        the dtype also matches. The computation runs on ``x``'s device.
    """
    pos = positions.to(device=x.device).float()
    freqs = inv_freq.to(device=x.device)
    angles = pos[:, None] * freqs[None, :]  # [S, Dh/2]
    cos, sin = angles.cos(), angles.sin()
    if x.is_floating_point():
        # Angles are built in at least float32 for accuracy, then cast so a
        # half/bfloat16 input is not silently promoted to float32 on output.
        cos, sin = cos.to(x.dtype), sin.to(x.dtype)
    # [S, Dh/2] broadcasts against the trailing dims of x. No leading
    # singleton axes are added, so inputs of any rank >= 2 keep their shape.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated_even = x_even * cos - x_odd * sin
    rotated_odd = x_even * sin + x_odd * cos
    # Re-interleave [..., Dh/2, 2] -> [..., Dh] to restore the channel order.
    return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2)


def position_interpolation_positions(seq_len: int, train_context: int, target_context: int) -> torch.Tensor:
    """Compress long positions back into the training range (position interpolation).

    Positions ``0 .. seq_len - 1`` are multiplied by
    ``min(1, train_context / target_context)``. This maps a
    ``target_context``-long window onto the range seen during training.

    Args:
        seq_len: Number of positions to generate.
        train_context: Context length the model was trained with (> 0).
        target_context: Context length to support at inference (> 0).

    Returns:
        Float32 tensor of length ``seq_len``.

    Raises:
        ValueError: If either context length is not positive.
    """
    if train_context <= 0 or target_context <= 0:
        raise ValueError(
            f"context lengths must be positive, got train_context={train_context}, "
            f"target_context={target_context}"
        )
    # Interpolation only ever compresses. A target that already fits inside
    # the training window must not stretch positions beyond the trained range.
    scale = min(1.0, train_context / float(target_context))
    return torch.arange(seq_len).float() * scale


def ntk_scaled_base(base: float, scale: float, head_dim: int) -> float:
    """Return an NTK-aware enlarged RoPE base (educational approximation).

    The base is raised by ``scale ** (head_dim / (head_dim - 2))``. The
    denominator is floored at 1 for tiny head dims. Scales of 1 or below
    mean "no extension" and leave the base unchanged.
    """
    if scale <= 1.0:
        return base
    denominator = max(head_dim - 2, 1)
    exponent = head_dim / denominator
    return base * scale ** exponent


def sliding_window_mask(seq_len: int, window: int, sink_tokens: int = 0) -> torch.Tensor:
    """Build a StreamingLLM-style boolean keep mask of shape [seq_len, seq_len].

    Entry ``[q, k]`` is True when key ``k`` is not in the future of query
    ``q`` and either lies within the last ``window`` positions (including
    ``q`` itself) or is one of the first ``sink_tokens`` positions.
    """
    idx = torch.arange(seq_len)
    query = idx.unsqueeze(1)
    key = idx.unsqueeze(0)
    distance = query - key  # how far back each key sits from its query
    is_causal = distance >= 0
    in_window = distance <= window - 1
    is_sink = key < sink_tokens
    return is_causal & (in_window | is_sink)


def smoke_test() -> None:
    """Exercise every helper on tiny CPU tensors and print a one-line summary."""
    torch.manual_seed(0)
    batch, heads, seq, head_dim = 1, 2, 8, 6
    q = torch.randn(batch, heads, seq, head_dim)
    positions = torch.arange(seq)

    # Baseline RoPE with integer positions.
    inv = rope_frequencies(head_dim=head_dim)
    normal = apply_rope(q, positions, inv)

    # Position interpolation: squeeze 8 positions into a 4-token training range.
    interp_pos = position_interpolation_positions(seq_len=seq, train_context=4, target_context=8)
    interpolated = apply_rope(q, interp_pos, inv)

    # NTK-style scaling: keep positions, enlarge the frequency base instead.
    ntk_base = ntk_scaled_base(10000.0, scale=2.0, head_dim=head_dim)
    ntk = apply_rope(q, positions, rope_frequencies(head_dim, base=ntk_base))

    mask = sliding_window_mask(seq_len=seq, window=3, sink_tokens=1)

    assert normal.shape == interpolated.shape == ntk.shape == q.shape
    assert mask.shape == (8, 8) and mask[7, 0] and not mask[7, 3]
    summary = {
        "shape": tuple(normal.shape),
        "pi_last_pos": round(interp_pos[-1].item(), 3),
        "mask_row7": mask[7].int().tolist(),
    }
    print("long_context_rope.py ok", summary)


if __name__ == "__main__":
    # Script entry point: run the CPU-only demo only when executed directly.
    smoke_test()
