"""InfraLens PyTorch example: decoding strategies for LLM whiteboarding.

Clean-room educational implementation of ordinary next-token decoding policy:
greedy selection, temperature, top-k, top-p, and one beam-search expansion.
It is intentionally separate from speculative decoding verification.
"""

from __future__ import annotations

import torch


def temperature_probs(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Convert logits to a probability distribution at a given temperature.

    The logits are divided by ``temperature`` before a softmax over the last
    dimension, so values below 1.0 sharpen the distribution and values above
    1.0 flatten it.

    Args:
        logits: Unnormalized scores; the last dimension is normalized.
        temperature: Strictly positive scaling factor.

    Returns:
        A tensor shaped like ``logits`` whose last dimension sums to 1.
    """
    # Zero or negative temperatures would divide by zero or invert the ranking.
    assert temperature > 0.0
    scaled = logits / temperature
    return scaled.softmax(dim=-1)


def greedy_token(logits: torch.Tensor) -> int:
    """Return the index of the largest logit along the last dimension.

    No randomness is involved. The result is converted with ``.item()``, so
    the argmax must reduce to a single element.
    """
    best_index = logits.argmax(dim=-1)
    return int(best_index.item())


def top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Mask every logit that is strictly below the k-th largest one.

    Masked positions are set to ``-inf`` so a later softmax assigns them zero
    probability. Entries equal to the k-th largest value are all kept, so ties
    at the cutoff can leave more than ``k`` finite entries.

    Args:
        logits: Scores to filter; the last dimension is the one ranked.
        k: Number of top entries to keep, between 1 and the last-dim size.

    Returns:
        A new tensor shaped like ``logits`` with the filtered scores.
    """
    vocab_size = logits.shape[-1]
    assert 0 < k <= vocab_size
    top_values, _ = torch.topk(logits, k, dim=-1)
    # The last of the k sorted values is the cutoff; keep a trailing axis of
    # size 1 so it broadcasts against the full last dimension.
    cutoff = top_values[..., -1:]
    below_cutoff = logits < cutoff
    return logits.masked_fill(below_cutoff, float("-inf"))


def top_p_logits(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Apply nucleus (top-p) filtering along the last dimension.

    Tokens are ranked by descending logit. A token is kept while the
    probability mass of the tokens ranked before it is still below ``p``;
    every other token is set to ``-inf``. The top-ranked token is always kept
    because the mass before it is zero and ``p`` is positive.

    Args:
        logits: Scores to filter; the last dimension is the one ranked.
        p: Target cumulative probability, in ``(0, 1]``.

    Returns:
        A new tensor shaped like ``logits``, in the original token order.
    """
    assert 0.0 < p <= 1.0
    ranked = torch.sort(logits, descending=True, dim=-1)
    ranked_probs = torch.softmax(ranked.values, dim=-1)
    # Exclusive cumulative sum: mass of all strictly higher-ranked tokens.
    mass_before = torch.cumsum(ranked_probs, dim=-1) - ranked_probs
    drop = mass_before >= p
    kept_ranked = ranked.values.masked_fill(drop, float("-inf"))
    # The sort indices form a permutation, so the scatter overwrites every
    # slot of the uninitialized buffer and restores the original order.
    restored = torch.empty_like(logits)
    return restored.scatter(-1, ranked.indices, kept_ranked)


def sample_token(logits: torch.Tensor, seed: int) -> int:
    """Draw one token id from ``softmax(logits)`` with a freshly seeded generator.

    A new generator is created and seeded on every call, so the same
    ``logits`` and ``seed`` on the same device always give the same token.

    Args:
        logits: Scores over the vocabulary; ``-inf`` entries get zero
            probability and are never sampled.
        seed: Seed for the per-call random generator.

    Returns:
        The sampled token index as a Python ``int``.
    """
    # torch.multinomial requires the generator to live on the same device as
    # the probabilities; a default (CPU) generator fails for accelerator
    # tensors, so build it on the logits' device.
    generator = torch.Generator(device=logits.device)
    generator.manual_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    drawn = torch.multinomial(probs, 1, generator=generator)
    return int(drawn.item())


def beam_search_step(
    beam_scores: torch.Tensor,
    next_log_probs: torch.Tensor,
    beam_width: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score every (beam, token) continuation and keep the best ``beam_width``.

    Each beam's running score is added to its row of ``next_log_probs``, the
    resulting table is flattened, and the top ``beam_width`` entries are
    selected across all beams at once.

    Args:
        beam_scores: Running score per beam.
        next_log_probs: Per-beam scores for each next token; the last
            dimension is the vocabulary.
        beam_width: Number of candidates to retain.

    Returns:
        A tuple ``(parent_beams, token_ids, scores)`` giving, for each
        retained candidate in descending score order, the beam it extends,
        the token appended, and the combined score.
    """
    vocab_size = next_log_probs.shape[-1]
    # Broadcast each beam's score across its row of token scores.
    extended = beam_scores.unsqueeze(1) + next_log_probs
    best = torch.topk(torch.flatten(extended), beam_width)
    flat_positions = best.indices
    # A flat position p in the row-major table maps to row p // V, column p % V.
    parent_beams = torch.div(flat_positions, vocab_size, rounding_mode="floor")
    token_ids = torch.remainder(flat_positions, vocab_size)
    return parent_beams, token_ids, best.values


def smoke_test() -> None:
    """Run each decoding helper on tiny fixed inputs and print a summary.

    Raises ``AssertionError`` if any helper disagrees with the expected
    result; otherwise prints one status line with the observed values.
    """
    logits = torch.tensor([3.0, 2.0, 1.0, 0.0])

    # Greedy: index 0 holds the largest logit.
    greedy_choice = greedy_token(logits)
    assert greedy_choice == 0

    # Temperature: a lower temperature puts more mass on the top token.
    sharp = temperature_probs(logits, temperature=0.5)
    flat = temperature_probs(logits, temperature=2.0)
    assert sharp[0] > flat[0]
    assert torch.allclose(sharp.sum(), torch.tensor(1.0))

    # Top-k: with k=2 everything after the first two entries is masked.
    filtered_k = top_k_logits(logits, k=2)
    assert torch.isneginf(filtered_k[2:]).all()
    top_k_kept = int(torch.isfinite(filtered_k).sum().item())

    # Top-p: p=0.75 keeps two tokens for these logits.
    filtered_p = top_p_logits(logits, p=0.75)
    top_p_kept = int(torch.isfinite(filtered_p).sum().item())
    assert top_p_kept == 2

    # Sampling: the same seed must reproduce the same draw.
    first_draw = sample_token(filtered_p, seed=17)
    second_draw = sample_token(filtered_p, seed=17)
    assert first_draw == second_draw

    # Beam search: each of the two beams contributes its own best token.
    raw_scores = torch.tensor([[4.0, 2.0, 1.0], [1.0, 5.0, 0.0]])
    next_log_probs = torch.log_softmax(raw_scores, dim=-1)
    running = torch.tensor([0.0, -0.3])
    parents, tokens, scores = beam_search_step(running, next_log_probs, beam_width=2)
    assert parents.tolist() == [0, 1]
    assert tokens.tolist() == [0, 1]

    summary = {
        "greedy": greedy_choice,
        "top_k_kept": top_k_kept,
        "top_p_kept": top_p_kept,
        "beam_tokens": tokens.tolist(),
        "beam_scores": [round(value, 4) for value in scores.tolist()],
    }
    print("decoding_strategies_whiteboard.py ok", summary)


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    smoke_test()
