"""InfraLens PyTorch example: LoRA for a decoder projection.

Clean-room educational implementation of a frozen linear weight plus a
trainable low-rank update. The smoke test checks shape, parameter count, and
merged-weight equivalence without training a model.
"""

from __future__ import annotations

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Linear projection with a frozen base weight and a trainable low-rank update.

    Computes ``y = x W^T + scale * x A^T B^T`` where ``W`` is the frozen base
    weight, ``A`` has shape ``(rank, in_features)``, ``B`` has shape
    ``(out_features, rank)``, and ``scale = alpha / rank``. ``B`` is
    initialised to zero, so a freshly constructed layer returns exactly the
    base projection.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Rank of the update; must satisfy
            ``0 < rank <= min(in_features, out_features)``.
        alpha: Numerator of the update scale ``alpha / rank``.

    Raises:
        ValueError: If ``rank`` is outside the valid range.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, alpha: float) -> None:
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let rank=0 reach ``alpha / rank`` below
        # and let oversized ranks through silently.
        max_rank = min(in_features, out_features)
        if not 0 < rank <= max_rank:
            raise ValueError(
                "rank must satisfy 0 < rank <= min(in_features, out_features) "
                f"= {max_rank}; got {rank}"
            )
        self.base = nn.Linear(in_features, out_features, bias=False)
        # The base weight is never trained; only the two factors receive gradients.
        self.base.weight.requires_grad_(False)
        self.a = nn.Linear(in_features, rank, bias=False)
        self.b = nn.Linear(rank, out_features, bias=False)
        self.scale = alpha / rank
        # Zero B so the low-rank branch contributes nothing at initialisation.
        nn.init.zeros_(self.b.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base projection of ``x`` plus the scaled low-rank update."""
        low_rank_update = self.b(self.a(x))
        return self.base(x) + self.scale * low_rank_update

    def merged_weight(self) -> torch.Tensor:
        """Return the deployment-time merged weight ``W + scale * B A``.

        The result has the same shape as the base weight. It is computed with
        autograd enabled, so callers that only need the values should detach
        it or call this under ``torch.no_grad()``.
        """
        delta = self.b.weight @ self.a.weight
        return self.base.weight + self.scale * delta

    def trainable_parameter_count(self) -> int:
        """Return the total number of elements in parameters that require grad."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def smoke_test() -> None:
    """Exercise LoRALinear on random input and print a summary on success.

    Checks, in order: the adapter is a no-op right after construction, the
    output keeps the input shape, the base weight is frozen, the trainable
    parameter count matches the two factors, the merged weight reproduces the
    adapted forward pass, and a non-zero B actually changes the output.
    """
    torch.manual_seed(11)
    batch, seq, hidden, rank = 2, 5, 16, 4
    layer = LoRALinear(hidden, hidden, rank=rank, alpha=8.0)
    inputs = torch.randn(batch, seq, hidden)

    frozen_out = layer.base(inputs)
    # B is zero at construction, so the full forward must equal the base path.
    assert torch.allclose(layer(inputs), frozen_out, atol=1e-6)

    # Give B non-zero values so the low-rank branch contributes to the output.
    with torch.no_grad():
        layer.b.weight.normal_(mean=0.0, std=0.02)

    lora_out = layer(inputs)
    merged_out = torch.nn.functional.linear(inputs, layer.merged_weight())
    trainable = layer.trainable_parameter_count()

    assert lora_out.shape == inputs.shape
    assert not layer.base.weight.requires_grad
    assert trainable == hidden * rank + rank * hidden
    assert torch.allclose(lora_out, merged_out, atol=1e-6)
    assert not torch.allclose(lora_out, frozen_out)

    summary = {
        "output_shape": tuple(lora_out.shape),
        "rank": rank,
        "scale": layer.scale,
        "trainable": trainable,
        "base_frozen": True,
        "initial_noop": True,
        "merge_equivalent": True,
    }
    print("lora_finetuning_whiteboard.py ok", summary)


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    smoke_test()
