
Jax ppo test

PPOLightningModule

Bases: LightningModule

Uses the same code as JaxRLExample, but runs the training loop with PyTorch Lightning.

This is currently only meant to be used to compare the fully-jitted training loop against the Lightning training loop.
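Here is a minimal sketch of the two loop styles being compared. It uses a toy quadratic update as a hypothetical stand-in for the PPO update (this is not JaxRLExample's code). `train_fully_jitted` compiles the whole loop with `jax.lax.scan` inside one `jax.jit`. `train_python_loop` keeps the loop in Python, the way a Lightning-driven loop does, and calls a jitted step once per batch.

```python
import jax
import jax.numpy as jnp


def update(params: jax.Array, batch: jax.Array) -> jax.Array:
    # Hypothetical toy update: one gradient-descent step on a quadratic loss.
    loss_fn = lambda p: jnp.mean((p - batch) ** 2)
    grads = jax.grad(loss_fn)(params)
    return params - 0.1 * grads


@jax.jit
def train_fully_jitted(params: jax.Array, batches: jax.Array) -> jax.Array:
    # The entire loop is traced and compiled once; no Python runs per step.
    def step(p, b):
        return update(p, b), None

    params, _ = jax.lax.scan(step, params, batches)
    return params


jitted_update = jax.jit(update)


def train_python_loop(params: jax.Array, batches: jax.Array) -> jax.Array:
    # Lightning-style: Python drives the loop and dispatches one compiled step per batch.
    for b in batches:
        params = jitted_update(params, b)
    return params
```

Both functions compute the same result. The difference lies in per-step Python and dispatch overhead, which is what such a comparison measures.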

RlThroughputCallback

Bases: MeasureSamplesPerSecondCallback

A callback to measure the throughput of RL algorithms.
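The idea behind such a throughput measurement can be sketched as below. This is a hypothetical stand-in. It does not subclass `MeasureSamplesPerSecondCallback` and does not reproduce that class's API. It assumes that one "batch" in RL corresponds to a rollout of `num_steps` steps across `num_envs` parallel environments, so each batch contributes `num_envs * num_steps` samples. The clock is injectable so the arithmetic can be checked deterministically.

```python
import time
from typing import Callable


class SamplesPerSecondMeter:
    """Hypothetical throughput meter; not the real callback's interface."""

    def __init__(self, clock: Callable[[], float] = time.perf_counter):
        self._clock = clock
        self._start: float | None = None
        self.total_samples = 0
        self.total_seconds = 0.0

    def on_batch_start(self) -> None:
        # Record when the current rollout/update started.
        self._start = self._clock()

    def on_batch_end(self, num_envs: int, num_steps: int) -> float:
        # Return this batch's samples/second and accumulate running totals.
        if self._start is None:
            raise RuntimeError("on_batch_end called before on_batch_start")
        elapsed = self._clock() - self._start
        self._start = None
        samples = num_envs * num_steps  # assumed definition of "samples" in RL
        self.total_samples += samples
        self.total_seconds += elapsed
        return samples / elapsed if elapsed > 0 else float("inf")

    @property
    def samples_per_second(self) -> float:
        # Average throughput over all measured batches.
        if self.total_seconds <= 0:
            return 0.0
        return self.total_samples / self.total_seconds
```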

test_rejax

test_rejax(
    rng: PRNGKey,
    results_rejax: tuple[PPO, Any, EvalMetrics],
    tensor_regression: TensorRegressionFixture,
    original_datadir: Path,
    seed: int | Sequence[int],
)

Train rejax.PPO with the same parameters.