"""InfraLens Python example: reference answers for resource estimation drills."""


def sharded_state_gb(
    params_b: float,
    dp: int,
    stage: int,
    *,
    param_bytes: float = 2,
    grad_bytes: float = 2,
    optimizer_bytes: float = 8,
) -> float:
    """Estimate per-rank model-state memory under ZeRO-style sharding.

    The three state components (parameters, gradients, optimizer state) are
    each sized as ``params_b * <bytes per parameter>``. A component is divided
    by ``dp`` once ``stage`` reaches the level that shards it; otherwise every
    rank holds it in full:

    * ``stage >= 1`` shards the optimizer state,
    * ``stage >= 2`` additionally shards the gradients,
    * ``stage >= 3`` additionally shards the parameters.

    Any ``stage`` below 1 means nothing is sharded; any ``stage`` above 3
    behaves like 3.

    Args:
        params_b: Model size; presumably billions of parameters, so that
            multiplying by bytes-per-parameter yields GB.
        dp: Number of data-parallel ranks the sharded state is split across.
        stage: Sharding stage (see above).
        param_bytes: Bytes stored per parameter for the weights.
        grad_bytes: Bytes stored per parameter for the gradients.
        optimizer_bytes: Bytes stored per parameter for the optimizer state.

    Returns:
        The summed per-rank size of all three components.

    Raises:
        ValueError: If ``dp`` is less than 1.
    """
    # Reject dp <= 0 up front: zero would divide by zero on any sharded
    # stage, and a negative degree would silently yield a meaningless size.
    if dp < 1:
        raise ValueError(f"dp must be >= 1, got {dp}")

    params = params_b * param_bytes
    grads = params_b * grad_bytes
    optimizer = params_b * optimizer_bytes

    per_rank = params / dp if stage >= 3 else params
    per_rank += grads / dp if stage >= 2 else grads
    per_rank += optimizer / dp if stage >= 1 else optimizer
    return per_rank


def ring_allreduce_payload_gb(tensor_gb: float, world: int) -> float:
    """Return the ring all-reduce traffic for a tensor of ``tensor_gb``.

    Evaluates ``2 * (world - 1) / world * tensor_gb``, the standard
    ring all-reduce volume factor applied to the tensor size. With a single
    participant (``world == 1``) the result is 0.

    Args:
        tensor_gb: Size of the tensor being reduced.
        world: Number of participants in the ring.

    Returns:
        The payload, in the same unit as ``tensor_gb``.

    Raises:
        ValueError: If ``world`` is less than 1.
    """
    # world == 0 would divide by zero, and a negative world would produce a
    # plausible-looking but wrong positive payload, so fail loudly instead.
    if world < 1:
        raise ValueError(f"world must be >= 1, got {world}")
    return 2 * (world - 1) / world * tensor_gb


def pipeline_bubble_fraction(stages: int, microbatches: int) -> float:
    """Return the pipeline bubble fraction for a given schedule shape.

    Evaluates ``(stages - 1) / (microbatches + stages - 1)``. A single-stage
    pipeline therefore has a bubble of 0, and the fraction shrinks as
    ``microbatches`` grows relative to ``stages``.

    Args:
        stages: Number of pipeline stages.
        microbatches: Number of microbatches per step.

    Returns:
        A value in ``[0, 1)``.

    Raises:
        ValueError: If ``stages`` or ``microbatches`` is less than 1.
    """
    # Without these guards stages == 0 gives a negative "fraction" and
    # stages == 1 with microbatches == 0 divides by zero.
    if stages < 1:
        raise ValueError(f"stages must be >= 1, got {stages}")
    if microbatches < 1:
        raise ValueError(f"microbatches must be >= 1, got {microbatches}")
    return (stages - 1) / (microbatches + stages - 1)


def smoke_test() -> None:
    """Check each estimator against its hand-computed reference value.

    Raises ``AssertionError`` at the first mismatch. When every check passes,
    prints a one-line summary containing the reference figures.
    """
    zero3_gb = sharded_state_gb(params_b=7, dp=8, stage=3)
    assert zero3_gb == 10.5

    ring_gb = ring_allreduce_payload_gb(tensor_gb=2, world=8)
    assert ring_gb == 3.5

    bubble = pipeline_bubble_fraction(stages=4, microbatches=16)
    assert round(bubble, 4) == 0.1579

    # The summary echoes the expected constants rather than the computed values.
    reference = {"zero3_gb": 10.5, "ring_payload_gb": 3.5, "bubble_pct": 15.79}
    print("performance_estimation_drills.py ok", reference)


# Run the smoke test only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    smoke_test()
