"""InfraLens PyTorch example: KL, DPO, PPO-style shaping, and GRPO advantage.

Educational clean-room implementation adapted from the algorithms covered by
ARIS-in-AI-Offer (MIT License). Uses toy log-probabilities and rewards; no model download.
"""

from __future__ import annotations

import torch
import torch.nn.functional as F


def token_kl(policy_logp: torch.Tensor, ref_logp: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-sequence KL estimate: masked mean of ``policy_logp - ref_logp``.

    The last dimension is treated as the token axis. Each token contributes
    its log-ratio between policy and reference; positions where ``mask`` is 0
    are excluded. The valid-token count is clamped to at least 1 to avoid
    dividing by zero on fully masked rows.

    Returns:
        A tensor with the token dimension reduced away.
    """
    log_ratio = policy_logp - ref_logp
    valid_tokens = mask.sum(dim=-1).clamp_min(1)
    masked_total = (log_ratio * mask).sum(dim=-1)
    return masked_total / valid_tokens


def ppo_reward_with_kl(
    reward: torch.Tensor,
    policy_logp: torch.Tensor,
    ref_logp: torch.Tensor,
    mask: torch.Tensor,
    beta: float,
) -> torch.Tensor:
    """Shape rewards with a KL penalty against the reference policy.

    Computes ``reward - beta * token_kl(policy_logp, ref_logp, mask)``.
    ``token_kl`` averages the per-token log-ratio over unmasked positions of
    the last dimension. ``reward`` is therefore presumably one scalar per
    sequence (e.g. shape ``[batch]``); confirm against callers.
    """
    kl_penalty = beta * token_kl(policy_logp, ref_logp, mask)
    return reward - kl_penalty


def dpo_loss(
    chosen_logp: torch.Tensor,
    rejected_logp: torch.Tensor,
    ref_chosen_logp: torch.Tensor,
    ref_rejected_logp: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """Direct Preference Optimization loss, averaged over all elements.

    For each element the logit is
    ``beta * ((chosen - rejected) - (ref_chosen - ref_rejected))``.
    The loss is ``-mean(logsigmoid(logit))``.

    Inputs are presumably summed sequence log-probs of the chosen/rejected
    responses under the policy and the frozen reference (TODO confirm). If
    per-token values are passed, ``.mean()`` also averages over tokens.
    """
    margin_gap = (chosen_logp - rejected_logp) - (ref_chosen_logp - ref_rejected_logp)
    preference_logits = beta * margin_gap
    return -F.logsigmoid(preference_logits).mean()


def grpo_advantages(rewards: torch.Tensor, group_size: int, *, eps: float = 1e-6) -> torch.Tensor:
    """Group-relative advantages: standardize each reward within its group.

    ``rewards`` may have any shape whose element count is a multiple of
    ``group_size``. Elements are grouped as consecutive runs of
    ``group_size`` in row-major (flattened) order, so a
    ``[prompts, group_size]`` tensor is normalized row by row and a flat
    tensor is split into consecutive chunks.

    Each reward becomes ``(r - group_mean) / group_std``. The std is the
    unbiased (n - 1) standard deviation, floored at ``eps``, so groups whose
    rewards are all identical produce zeros instead of dividing by zero.

    Args:
        rewards: Floating-point reward tensor with
            ``numel() % group_size == 0``.
        group_size: Number of samples per group.
        eps: Lower bound applied to the per-group standard deviation.

    Returns:
        A tensor with the same shape as ``rewards`` holding the advantages.
        Groups of size 1 yield zero advantage instead of NaN.

    Raises:
        RuntimeError: if ``rewards`` cannot be split into groups of
            ``group_size`` (raised by ``reshape``).
    """
    # reshape (not view) so non-contiguous inputs such as transposed tensors work.
    grouped = rewards.reshape(-1, group_size)
    centered = grouped - grouped.mean(dim=1, keepdim=True)
    if group_size == 1:
        # The unbiased std of a single sample is NaN; the centered value is
        # already exactly 0, which is the only meaningful relative advantage.
        return centered.reshape(rewards.shape)
    std = grouped.std(dim=1, keepdim=True).clamp_min(eps)
    return (centered / std).reshape(rewards.shape)


def smoke_test() -> None:
    """Run each objective on tiny hand-built tensors and print the results."""
    torch.manual_seed(0)

    # KL-shaped reward: two length-4 sequences with 3 and 2 valid tokens.
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)
    policy_logp = torch.randn(2, 4) - 2.0
    # Second randn call: the reference is a small perturbation of the policy.
    ref_logp = policy_logp - 0.1 * torch.randn(2, 4)
    base_rewards = torch.tensor([1.0, 0.5])
    shaped = ppo_reward_with_kl(base_rewards, policy_logp, ref_logp, mask, beta=0.05)

    # DPO: chosen / rejected log-probs under policy, then under reference.
    loss = dpo_loss(
        torch.tensor([[-4.0], [-3.0]]),
        torch.tensor([[-5.0], [-3.5]]),
        torch.tensor([[-4.2], [-3.2]]),
        torch.tensor([[-4.8], [-3.4]]),
    )

    # GRPO: four rewards split into two groups of two.
    adv = grpo_advantages(torch.tensor([1.0, 2.0, 0.0, 3.0]), group_size=2)

    report = {
        "shaped_reward": shaped.tolist(),
        "dpo_loss": round(loss.item(), 4),
        "grpo_adv": adv.tolist(),
    }
    print("rlhf_dpo_grpo.py ok", report)


if __name__ == "__main__":
    # Executing the file as a script runs the toy smoke test; importing it does not.
    smoke_test()
