
JAX RL example

Example of an RL algorithm (PPO) written entirely in JAX.

This is based on rejax.PPO. See the JaxRLExample class for a description of the differences w.r.t. rejax.PPO.

TEnvParams module-attribute

TEnvParams = TypeVar(
    "TEnvParams", bound=EnvParams, default=EnvParams
)

Type variable for the env params (gymnax.EnvParams).

Trajectory

Bases: PyTreeNode

A sequence of interactions between an agent and an environment.

TrajectoryWithLastObs

Bases: PyTreeNode

A trajectory together with the last observation and a flag indicating whether the last step ended an episode.

AdvantageMinibatch

Bases: PyTreeNode

Trajectories annotated with advantages and with target values for the critic.
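PPO implementations usually compute these advantages and critic targets with generalized advantage estimation (GAE). The following is a minimal plain-Python sketch of that computation for a single trajectory, assuming GAE; the function name `compute_gae` and the parameter names `gamma` and `gae_lambda` are hypothetical, and the actual code operates on batched JAX arrays and may differ in details such as how episode ends are handled.

```python
from typing import Sequence


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    dones: Sequence[bool],
    last_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> tuple[list[float], list[float]]:
    """Return (advantages, critic targets) for one trajectory of length T.

    values[t] is the critic's estimate V(s_t), dones[t] is True if the episode
    ended at step t, and last_value is V(s_T) for the last observation.
    """
    T = len(rewards)
    advantages = [0.0] * T
    gae = 0.0
    next_value = last_value
    # Sweep backwards: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    for t in reversed(range(T)):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; no bootstrapping past the end of an episode.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # The critic is regressed onto advantage + value (the lambda-return).
    targets = [a + v for a, v in zip(advantages, values)]
    return advantages, targets
```

With `gamma=1.0` and `gae_lambda=1.0`, the advantages reduce to the undiscounted reward-to-go minus the value baseline, truncated at episode ends.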

TrajectoryCollectionState

Bases: Generic[TEnvState], PyTreeNode

Struct containing the state related to the collection of data from the environment.

PPOState

Bases: Generic[TEnvState], PyTreeNode

Contains all the state of the JaxRLExample algorithm.

PPOHParams

Bases: PyTreeNode

Hyper-parameters for this PPO example.

These are taken from the rejax.PPO algorithm class.

JaxRLExample

Bases: PyTreeNode, JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics], Generic[TEnvState, TEnvParams]

Example of an RL algorithm written in JAX: PPO, based on rejax.PPO.

Differences w.r.t. rejax.PPO:

  • The state / hparams are split into different, fully-typed structs:
    • The algorithm state is in a typed PPOState struct (vs an untyped, dynamically-generated struct in rejax).
    • The hyper-parameters are in a typed PPOHParams struct.
    • The state variables related to collecting data from the environment are grouped in a TrajectoryCollectionState, instead of being bunched up with everything else.
      • This makes it easier to call the collect_episodes function with just what it needs.
  • The seeds for the networks and the environment data collection are separated.
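To illustrate what a typed, immutable hyper-parameter struct offers over an untyped one, here is a sketch that uses a frozen stdlib dataclass as a stand-in for a flax.struct.PyTreeNode. The class name `HParamsSketch` and all of its fields and defaults are hypothetical; they are not the actual PPOHParams fields.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class HParamsSketch:
    # Hypothetical subset of PPO hyper-parameters; names and defaults are
    # illustrative only.
    learning_rate: float = 3e-4
    num_epochs: int = 8
    num_minibatches: int = 8
    clip_eps: float = 0.2
    gamma: float = 0.99
    gae_lambda: float = 0.95

    def replace(self, **changes) -> "HParamsSketch":
        # Mirrors the `.replace` method of flax.struct dataclasses: returns a
        # new instance and leaves the original untouched.
        return dataclasses.replace(self, **changes)
```

Assigning to a field raises `dataclasses.FrozenInstanceError`, and a misspelled name passed to `replace(...)` raises `TypeError` instead of silently adding a new entry, which is the kind of mistake an untyped, dynamically-generated struct lets through.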

The logic is otherwise exactly the same: the losses and updates are computed in the same way as in rejax.PPO.
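For reference, the core of the PPO policy update is the clipped surrogate objective. The sketch below computes it in plain Python for one minibatch. The function name and the `clip_eps` default are assumptions, and it leaves out the critic loss, the entropy bonus, and any advantage normalization that the full loss may include.

```python
import math
from typing import Sequence


def ppo_clipped_policy_loss(
    log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """Mean clipped surrogate loss (to be minimized) over a minibatch."""
    total = 0.0
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s), computed from log-probs.
        ratio = math.exp(lp - old_lp)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # PPO maximizes min(ratio * A, clipped * A); negate it to get a loss.
        total += -min(ratio * adv, clipped * adv)
    return total / len(advantages)
```

Taking the minimum makes the objective pessimistic: once the ratio leaves `[1 - clip_eps, 1 + clip_eps]` in the direction that the advantage rewards, the gradient through the ratio vanishes.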

training_step

training_step(
    batch_idx: int,
    ts: PPOState[TEnvState],
    batch: TrajectoryWithLastObs,
)

Training step in pure JAX.
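A PPO training step typically makes several passes (epochs) over the collected batch, reshuffling it each time and splitting it into minibatches. Assuming training_step follows that pattern, the iteration order can be sketched in plain Python as below. `iterate_minibatches`, `num_epochs` and `num_minibatches` are hypothetical names; a JAX version would use `jax.random.permutation` with a fresh key per epoch instead of `random.Random`.

```python
import random
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def iterate_minibatches(
    batch: Sequence[T],
    num_epochs: int,
    num_minibatches: int,
    rng: random.Random,
) -> Iterator[list[T]]:
    """Yield num_epochs * num_minibatches shuffled minibatches of `batch`."""
    if len(batch) % num_minibatches != 0:
        raise ValueError("batch size must be divisible by num_minibatches")
    size = len(batch) // num_minibatches
    for _ in range(num_epochs):
        # A fresh permutation each epoch, so minibatch composition changes.
        perm = list(range(len(batch)))
        rng.shuffle(perm)
        for i in range(num_minibatches):
            idx = perm[i * size:(i + 1) * size]
            yield [batch[j] for j in idx]
```

Each epoch visits every sample exactly once, which is why the batch size is required to divide evenly into minibatches.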

train

train(
    rng: Array,
    train_state: PPOState[TEnvState] | None = None,
    skip_initial_evaluation: bool = False,
) -> tuple[PPOState[TEnvState], EvalMetrics]

Full training loop in JAX.

This method exists only to match the API of rejax.PPO.train. It is not called when using the JaxTrainer, since JaxTrainer.fit already does the same thing, with added support for some JaxCallbacks (as well as some lightning.Callbacks).

Unfolded version of rejax.PPO.train.
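The control flow of a full training loop like train can be sketched as follows. Everything here is hypothetical (the helper callables, `eval_every`, and the integer toy state); the sketch only shows how the optional `train_state` and `skip_initial_evaluation` arguments plausibly fit in, and it uses a Python `for` loop where a JAX implementation would keep the loop inside jit-compiled code.

```python
from typing import Callable, Optional, TypeVar

S = TypeVar("S")  # training state
M = TypeVar("M")  # evaluation metrics


def train_loop(
    init_state: Callable[[int], S],
    fused_step: Callable[[int, S], S],
    evaluate: Callable[[S], M],
    seed: int,
    num_iterations: int,
    eval_every: int,
    train_state: Optional[S] = None,
    skip_initial_evaluation: bool = False,
) -> tuple[S, list[M]]:
    # Resume from an existing state if one is given, else initialize from the seed.
    state = init_state(seed) if train_state is None else train_state
    metrics: list[M] = []
    if not skip_initial_evaluation:
        metrics.append(evaluate(state))
    for iteration in range(num_iterations):
        # One iteration = collect data and update (see fused_training_step).
        state = fused_step(iteration, state)
        if (iteration + 1) % eval_every == 0:
            metrics.append(evaluate(state))
    return state, metrics
```

With a counter as the state, `train_loop(lambda s: s, lambda i, s: s + 1, lambda s: s * 10, seed=0, num_iterations=4, eval_every=2)` returns `(4, [0, 20, 40])`.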

fused_training_step

fused_training_step(
    iteration: int, ts: PPOState[TEnvState]
)

Fused training step in JAX (data collection and training combined in a single step).

Much faster than going through PyTorch Lightning, but you lose the callbacks and related features.
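One common way to get this kind of speed-up is to put collection and update inside a single function and loop over it with `jax.lax.scan` under `jax.jit`, so the whole run compiles to one XLA program with no per-step Python overhead. The sketch below emulates scan's documented `(carry, x) -> (carry, y)` semantics in plain Python so it runs without JAX. `ToyState`, `collect` and `fused_step` are hypothetical, and their argument order follows scan rather than the fused_training_step signature.

```python
from typing import Any, Callable, Iterable, NamedTuple


def scan(f: Callable[[Any, Any], tuple[Any, Any]], init: Any, xs: Iterable[Any]):
    """Pure-Python reference semantics of jax.lax.scan (outputs kept as a list)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


class ToyState(NamedTuple):
    env_steps: int  # hypothetical: total environment steps collected so far
    updates: int    # hypothetical: number of parameter updates applied


def collect(state: ToyState, num_steps: int) -> list[int]:
    # Stand-in for rolling out the policy in the environment.
    return list(range(state.env_steps, state.env_steps + num_steps))


def fused_step(state: ToyState, iteration: int) -> tuple[ToyState, dict]:
    # Collection and the (toy) update happen in one function, so under
    # jax.jit + lax.scan the whole loop would be compiled together.
    trajectory = collect(state, num_steps=4)
    new_state = ToyState(state.env_steps + len(trajectory), state.updates + 1)
    return new_state, {"iteration": iteration, "batch_size": len(trajectory)}


final_state, metrics = scan(fused_step, ToyState(0, 0), range(3))
```

The trade-off mentioned above follows from this design: Python-side callbacks cannot run between iterations that live inside a single compiled loop.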