JAX RL example

Example of an RL algorithm (PPO) written entirely in JAX, based on rejax.PPO.

See the JaxRLExample class for a description of the differences w.r.t. rejax.PPO.
TEnvParams module-attribute #

Type variable for the env params (gymnax.EnvParams).
Trajectory #
TrajectoryWithLastObs #
Bases: PyTreeNode
Trajectory that also stores the last observation and whether the last step is the end of an episode.
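A minimal sketch of what such a struct looks like. The real class is a flax PyTreeNode; this sketch uses a plain NamedTuple (which JAX also treats as a pytree), and the field names here are illustrative assumptions, not the actual attribute names.

```python
from typing import Any, NamedTuple


class TrajectoryWithLastObs(NamedTuple):
    """Sketch only; field names are assumed, not taken from the source."""

    trajectories: Any  # stacked per-step data (obs, actions, rewards, ...)
    last_done: Any     # whether the final step ended the episode
    last_obs: Any      # observation that follows the final step
```

Because NamedTuples are registered pytrees, an instance like this can flow through `jax.jit` / `jax.lax.scan` without extra registration.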
AdvantageMinibatch #
TrajectoryCollectionState #
Bases: Generic[TEnvState], PyTreeNode
Struct containing the state related to the collection of data from the environment.
PPOState #
PPOHParams #
Bases: PyTreeNode
Hyper-parameters for this PPO example.
These are taken from the rejax.PPO algorithm class.
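To illustrate the "typed hyper-parameter struct" idea, here is a hedged sketch using a NamedTuple. The specific fields and default values below are assumptions chosen as common PPO settings, not the actual contents of PPOHParams.

```python
from typing import NamedTuple


class PPOHParams(NamedTuple):
    """Sketch of a fully-typed hyper-parameter struct (values are assumed)."""

    num_envs: int = 64        # parallel environments used for collection
    num_epochs: int = 8       # optimization epochs per rollout
    gamma: float = 0.99       # discount factor
    gae_lambda: float = 0.95  # GAE bias/variance trade-off
    clip_eps: float = 0.2     # PPO clipping range
```

A typed struct like this gives IDE completion and static checking, unlike a dynamically-generated container.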
JaxRLExample #
Bases: PyTreeNode, JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics], Generic[TEnvState, TEnvParams]

Example of an RL algorithm written in JAX: PPO, based on rejax.PPO.
Differences w.r.t. rejax.PPO #

- The state / hparams are split into separate, fully-typed structs:
    - The algorithm state is in a typed PPOState struct (vs an untyped, dynamically-generated struct in rejax).
    - The hyper-parameters are in a typed PPOHParams struct.
    - The state variables related to the collection of data from the environment are in a TrajectoryCollectionState struct instead of everything being bunched up together. This makes it easier to call the collect_episodes function with just what it needs.
- The seeds for the networks and the environment data collection are separated.

The logic is exactly the same: the losses / updates are computed in the exact same way.
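The separated-seeds point can be sketched with standard JAX PRNG handling: derive independent keys for network initialization and environment data collection from one root key. The variable names here are illustrative, not the actual attribute names in the example.

```python
import jax

# One root key, split into independent streams (names are illustrative):
# reusing a key for both concerns would correlate weight init with env resets.
root_rng = jax.random.PRNGKey(0)
networks_rng, env_rng = jax.random.split(root_rng)
```

Keeping the two streams in separate struct fields makes it explicit which randomness affects which part of training.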
training_step #
training_step(
    batch_idx: int,
    ts: PPOState[TEnvState],
    batch: TrajectoryWithLastObs,
)

Training step in pure JAX.
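A pure training step (state in, state out, no side effects) composes directly with jax.lax.scan. This toy sketch uses a scalar carry in place of the real PPOState just to show the pattern; it is not the example's actual step.

```python
import jax
import jax.numpy as jnp


def training_step(carry, batch_idx):
    """Toy pure step: the carry stands in for a real training state."""
    new_carry = carry + jnp.asarray(batch_idx, carry.dtype)
    return new_carry, new_carry  # (next state, per-step output)


# Run three "training steps" as one compiled scan.
final, history = jax.lax.scan(training_step, jnp.float32(0.0), jnp.arange(3))
```

Because the step is pure, the whole loop can be jitted and differentiated as a single function.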
train #
train(
    rng: Array,
    train_state: PPOState[TEnvState] | None = None,
    skip_initial_evaluation: bool = False,
) -> tuple[PPOState[TEnvState], EvalMetrics]

Full training loop in JAX.
This is only here to match the API of rejax.PPO.train. It doesn't get called when using the JaxTrainer, since JaxTrainer.fit already does the same thing, but with added support for JaxCallbacks (as well as some lightning.Callbacks!).

Unfolded version of rejax.PPO.train.