"""InfraLens PyTorch example: top-k MoE routing and expert dispatch.

Educational clean-room implementation adapted from the algorithms covered by
ARIS-in-AI-Offer (MIT License). Runs a tiny sparse FFN layer on CPU.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMoE(nn.Module):
    """Tiny sparse mixture-of-experts FFN layer with top-k token routing.

    A bias-free linear router scores every token against ``num_experts``
    experts. Each token is sent to its ``top_k`` highest-probability experts,
    and the expert outputs are combined using the selected probabilities
    renormalized to sum to one per token.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        expert_dim: Inner width of each expert's two-layer GELU MLP.
        num_experts: Number of expert MLPs.
        top_k: Number of experts each token is routed to.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, hidden_dim: int = 8, expert_dim: int = 16, num_experts: int = 4, top_k: int = 2) -> None:
        super().__init__()
        # Fail at construction time rather than deep inside torch.topk.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be between 1 and num_experts={num_experts}, got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_dim, expert_dim),
                    nn.GELU(),
                    nn.Linear(expert_dim, hidden_dim),
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Route tokens to their top-k experts and mix the expert outputs.

        Args:
            x: Tensor of shape ``[..., hidden_dim]``. All leading dimensions
                are treated as independent tokens.

        Returns:
            A pair ``(out, info)``. ``out`` has the same shape as ``x``.
            ``info`` contains:

            * ``"router_logits"``: raw router scores, ``[..., num_experts]``.
            * ``"expert_load"``: number of (token, slot) assignments received
              by each expert, shape ``[num_experts]``.
            * ``"aux_loss"``: scalar ``num_experts * sum(p_i * f_i)`` where
              ``p_i`` is the mean router probability of expert ``i`` and
              ``f_i`` its share of all assignments. It is 0 when there are
              no tokens.
        """
        # Collapse any leading dims so routing always works on [tokens, hidden].
        tokens = x if x.dim() == 2 else x.reshape(-1, x.shape[-1])
        num_tokens = tokens.shape[0]

        logits = self.router(tokens)
        # Computed once; reused for both expert selection and the aux loss.
        router_probs = F.softmax(logits, dim=-1)

        out = torch.zeros_like(tokens)
        load = torch.zeros(self.num_experts, device=x.device)

        if num_tokens == 0:
            # The mean over an empty token axis would be NaN; nothing was
            # routed, so there is no imbalance to penalize.
            aux_loss = logits.new_zeros(())
        else:
            weights, indices = torch.topk(router_probs, self.top_k, dim=-1)
            # Renormalize so the selected experts' weights sum to 1 per token.
            weights = weights / weights.sum(dim=-1, keepdim=True)

            for expert_id, expert in enumerate(self.experts):
                # token_pos: rows routed to this expert; slot_pos: which of
                # the token's top-k slots selected it (used to pick the weight).
                token_pos, slot_pos = torch.where(indices == expert_id)
                if token_pos.numel() == 0:
                    continue
                expert_out = expert(tokens[token_pos])
                # topk indices are distinct per token, so token_pos has no
                # duplicates and the indexed in-place add is well-defined.
                out[token_pos] += expert_out * weights[token_pos, slot_pos].unsqueeze(-1)
                load[expert_id] = token_pos.numel()

            mean_probs = router_probs.mean(dim=0)
            load_fraction = load / load.sum().clamp_min(1.0)
            # load is built from counts and carries no gradient; the gradient
            # of this term reaches the router through mean_probs only.
            aux_loss = self.num_experts * torch.sum(mean_probs * load_fraction)

        leading_shape = x.shape[:-1]
        info = {
            "router_logits": logits.reshape(*leading_shape, self.num_experts),
            "expert_load": load,
            "aux_loss": aux_loss,
        }
        return out.reshape(x.shape), info


def smoke_test() -> None:
    """Run a seeded forward pass through a default ``TinyMoE`` and print a summary.

    Checks that the output keeps the input shape and that the per-expert load
    adds up to ``tokens * top_k`` assignments.
    """
    # Seed first so both the random input and the layer init are reproducible.
    torch.manual_seed(0)
    inputs = torch.randn(10, 8)
    layer = TinyMoE()

    outputs, stats = layer(inputs)
    expert_load = stats["expert_load"]

    assert outputs.shape == inputs.shape
    assert expert_load.sum().item() == 10 * layer.top_k

    summary = {
        "output_shape": tuple(outputs.shape),
        "expert_load": expert_load.tolist(),
        "aux_loss": round(stats["aux_loss"].item(), 4),
    }
    print("moe_routing.py ok", summary)


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    smoke_test()
