
Jax rl example test

PPOLightningModule

Bases: LightningModule

Uses the same code as project.algorithms.jax_rl_example.JaxRLExample, but runs the training loop with PyTorch Lightning.

This is currently only meant to be used to compare the fully-jitted training loop with the same training loop driven by Lightning.
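The comparison this module exists for can be illustrated with a minimal sketch. It assumes a toy quadratic loss in place of the PPO update, and the names `train_step`, `fully_jitted_train` and `python_loop_train` are hypothetical, not part of the project. The first version compiles the whole loop into one XLA program with `jax.lax.scan`; the second keeps a Python-level loop that calls a jitted step each iteration, which is roughly the structure an outer trainer such as Lightning's imposes.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical toy update standing in for one RL training step (not JaxRLExample code).
def train_step(params: jax.Array, _unused=None) -> tuple[jax.Array, jax.Array]:
    def loss_fn(p):
        return jnp.sum((p - 3.0) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    return params - 0.1 * grads, loss


# Fully-jitted: the entire loop is traced once and compiled into a single program.
@functools.partial(jax.jit, static_argnames="num_steps")
def fully_jitted_train(params: jax.Array, num_steps: int):
    # scan carries `params` through `num_steps` updates and stacks the losses.
    return jax.lax.scan(train_step, params, xs=None, length=num_steps)


# Stepwise: a Python loop dispatches a jitted step once per iteration.
jitted_step = jax.jit(train_step)


def python_loop_train(params: jax.Array, num_steps: int):
    losses = []
    for _ in range(num_steps):
        params, loss = jitted_step(params, None)
        losses.append(loss)
    return params, jnp.stack(losses)


init = jnp.zeros(2)
params_scan, losses_scan = fully_jitted_train(init, num_steps=50)
params_loop, losses_loop = python_loop_train(init, num_steps=50)
```

Both loops compute the same updates, so any wall-clock difference between them comes from the per-step dispatch and host-side overhead of the outer Python loop rather than from the math.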

RlThroughputCallback

Bases: MeasureSamplesPerSecondCallback

A callback to measure the throughput of RL algorithms.
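What "throughput" counts for an RL algorithm is not spelled out here. A common convention, assumed in the sketch below, is to count environment transitions (number of parallel environments times rollout length) per second of wall-clock time. The `ThroughputMeter` class and its hook names `on_rollout_start` / `on_rollout_end` are hypothetical; they do not reproduce the API of `MeasureSamplesPerSecondCallback`.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field


def samples_per_second(num_samples: int, elapsed_seconds: float) -> float:
    """Throughput in environment samples per second of wall-clock time."""
    if elapsed_seconds <= 0:
        raise ValueError("elapsed_seconds must be positive")
    return num_samples / elapsed_seconds


@dataclass
class ThroughputMeter:
    """Hypothetical stand-in for a samples-per-second callback (not the project's API).

    Assumes one "sample" is one environment transition, so a rollout of
    `num_steps_per_rollout` steps over `num_envs` parallel environments
    yields num_envs * num_steps_per_rollout samples.
    """

    num_envs: int
    num_steps_per_rollout: int
    history: list[float] = field(default_factory=list)
    _start: float | None = None

    def on_rollout_start(self) -> None:
        # Record a high-resolution wall-clock timestamp.
        self._start = time.perf_counter()

    def on_rollout_end(self) -> float:
        if self._start is None:
            raise RuntimeError("on_rollout_start() must be called first")
        elapsed = time.perf_counter() - self._start
        self._start = None
        sps = samples_per_second(self.num_envs * self.num_steps_per_rollout, elapsed)
        self.history.append(sps)
        return sps
```

Usage: call `on_rollout_start()` before collecting a rollout and `on_rollout_end()` after it; each call appends one samples-per-second measurement to `history`.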

test_rejax

test_rejax(
    rng: PRNGKey,
    results_rejax: tuple[PPO, Any, EvalMetrics],
    tensor_regression: TensorRegressionFixture,
    original_datadir: Path,
    n_agents: int | None,
)

Train rejax.PPO with the same parameters.