Jax PPO
Example of an RL algorithm (PPO) written entirely in Jax.
This is based on rejax.PPO.
See the JaxRLExample class for a description of the differences w.r.t. rejax.PPO.
TEnvParams module-attribute #
Type variable for the env params (gymnax.EnvParams).
Trajectory #
TrajectoryWithLastObs #
Bases: PyTreeNode
Trajectory with the last observation and whether the last step is the end of an episode.
AdvantageMinibatch #
TrajectoryCollectionState #
Bases: Generic[TEnvState], PyTreeNode
Struct containing the state related to the collection of data from the environment.
PPOState #
PPOHParams #
Bases: PyTreeNode
Hyper-parameters for this PPO example.
These are taken from the rejax.PPO algorithm class.
JaxRLExample #
Bases: PyTreeNode, JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics], Generic[TEnvState, TEnvParams]
Example of an RL algorithm written in Jax: PPO, based on rejax.PPO.
Differences w.r.t. rejax.PPO:#
- The state / hparams are split into separate, fully-typed structs:
    - The algorithm state is in a typed `PPOState` struct (vs an untyped, dynamically-generated struct in rejax).
    - The hyper-parameters are in a typed `PPOHParams` struct.
    - The state variables related to the collection of data from the environment are in a `TrajectoryCollectionState` struct instead of everything being bunched up together.
        - This makes it easier to call the `collect_episodes` function with just what it needs.
- The seeds for the networks and the environment data collection are separated.
The logic is exactly the same: The losses / updates are computed in the exact same way.
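As a rough illustration of this split, here is a stdlib-only sketch using plain dataclasses. The real structs are flax.struct.PyTreeNode subclasses, and the field names below are assumptions for illustration, not the actual attributes:

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

TEnvState = TypeVar("TEnvState")


@dataclass(frozen=True)
class TrajectoryCollectionState(Generic[TEnvState]):
    """State related to collecting data from the environment."""
    rng: Any              # PRNG key for env steps (separate from network seeds)
    env_state: TEnvState  # current environment state
    last_obs: Any         # last observation seen


@dataclass(frozen=True)
class PPOHParams:
    """Typed hyper-parameters (vs. untyped attributes in rejax)."""
    learning_rate: float = 3e-4
    gamma: float = 0.99


@dataclass(frozen=True)
class PPOState(Generic[TEnvState]):
    """Typed algorithm state: each concern lives in its own sub-struct."""
    actor_params: Any
    critic_params: Any
    data_collection_state: TrajectoryCollectionState[TEnvState]


def collect_episodes(cs: TrajectoryCollectionState) -> TrajectoryCollectionState:
    # Placeholder body: the point is that this function only needs the
    # collection state, not the whole PPOState.
    return cs
```

Because `collect_episodes` takes only the `TrajectoryCollectionState`, its inputs and outputs are explicit, instead of threading one large untyped state object everywhere.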
training_step #
training_step(
batch_idx: int,
ts: PPOState[TEnvState],
batch: TrajectoryWithLastObs,
)
Training step in pure jax.
train #
train(
rng: Array,
train_state: PPOState[TEnvState] | None = None,
skip_initial_evaluation: Static[bool] = False,
) -> tuple[PPOState[TEnvState], EvalMetrics]
Full training loop in jax.
This is only here to match the API of rejax.PPO.train. This doesn't get called when using
the JaxTrainer, since JaxTrainer.fit already does the same thing, but also with support
for some JaxCallbacks (as well as some lightning.Callbacks!).
Unfolded version of rejax.PPO.train.
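The shape of that unfolded loop can be sketched without any Jax dependency. This is an illustrative stand-in only: the real implementation scans a pure `training_step` (e.g. with jax.lax.scan), and the parameter names here are assumptions:

```python
from typing import Any, Callable


def train(
    training_step: Callable[[int, Any], tuple[Any, Any]],
    initial_state: Any,
    num_iterations: int,
    skip_initial_evaluation: bool = False,
) -> tuple[Any, list[Any]]:
    """Repeatedly apply a pure training step, collecting evaluation metrics.

    Mirrors the loop structure of a scan: carry the state through each step
    and stack the per-step outputs.
    """
    evaluations: list[Any] = []
    state = initial_state
    if not skip_initial_evaluation:
        evaluations.append(state)  # placeholder for an initial eval metric
    for batch_idx in range(num_iterations):
        state, metrics = training_step(batch_idx, state)
        evaluations.append(metrics)
    return state, evaluations


# Toy "training step" that just counts updates:
final_state, evals = train(lambda i, s: (s + 1, s + 1), initial_state=0, num_iterations=3)
# final_state == 3, evals == [0, 1, 2, 3]
```

The `skip_initial_evaluation` flag matches the signature above: when set, the state before any updates is not evaluated.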
field #
field(
*,
default: T | _MISSING_TYPE = MISSING,
default_factory: (
Callable[[], T] | _MISSING_TYPE
) = MISSING,
init=True,
repr=True,
hash=None,
compare=True,
metadata: Mapping[Any, Any] | None = None,
kw_only=MISSING,
pytree_node: bool | None = None
) -> T
Small typing fix for flax.struct.field:
- Add type annotations so it doesn't drop the signature of the `dataclasses.field` function.
- Make `pytree_node` default to `False` for ints and bools, and `True` for everything else.