Jax trainer
Ts module-attribute #
Ts = TypeVar('Ts', bound=PyTreeNode, default=PyTreeNode)
Type variable for the training state.
JaxModule #
Bases: Protocol[Ts, _B, _MetricsT]
A protocol for algorithms that can be trained by the JaxTrainer.
The JaxRLExample is an example that follows this structure and can be trained with a JaxTrainer.
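For concreteness, here is a minimal sketch of what an implementation of this protocol might look like. Only the method names (`init_train_state`, `get_batch`, `training_step`) are taken from the training loop shown below; the `TrainState`, `Metrics` and `MyAlgo` classes, their fields, and the toy loss are purely hypothetical.

```python
import jax
import jax.numpy as jnp
from flax.struct import PyTreeNode


class TrainState(PyTreeNode):
    # Hypothetical training state (the `Ts` type parameter).
    params: jax.Array
    step: jax.Array


class Metrics(PyTreeNode):
    # Hypothetical metrics struct (the `_MetricsT` type parameter).
    loss: jax.Array


class MyAlgo(PyTreeNode):
    # Hypothetical hyper-parameter.
    learning_rate: float = 1e-3

    def init_train_state(self, rng: jax.Array) -> TrainState:
        # Create the initial training state from a PRNG key.
        return TrainState(params=jax.random.normal(rng, (8,)), step=jnp.array(0))

    def get_batch(self, ts: TrainState, step: int) -> jax.Array:
        # Produce the batch (the `_B` type parameter) for this training step.
        return jnp.ones((4, 8))

    def training_step(
        self, step: int, ts: TrainState, batch: jax.Array
    ) -> tuple[TrainState, Metrics]:
        # One SGD step on a toy quadratic loss; returns the new state and metrics.
        loss, grads = jax.value_and_grad(lambda p: jnp.mean((batch @ p) ** 2))(ts.params)
        new_state = ts.replace(params=ts.params - self.learning_rate * grads, step=ts.step + 1)
        return new_state, Metrics(loss=loss)
```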
JaxTrainer #
Bases: PyTreeNode
A simplified version of the lightning.Trainer with a fully jitted training loop.
Assumptions#
- The algo object must match the JaxModule protocol (in other words, it should implement its methods).
Training loop#
This is the training loop, which is fully jitted:

```python
ts = algo.init_train_state(rng)
setup("fit")
on_fit_start()
on_train_start()
eval_metrics = []
for epoch in range(self.max_epochs):
    on_train_epoch_start()
    for step in range(self.training_steps_per_epoch):
        batch = algo.get_batch(ts, step)
        on_train_batch_start()
        ts, metrics = algo.training_step(step, ts, batch)
        on_train_batch_end()
    on_train_epoch_end()
    # Evaluation "loop"
    on_validation_epoch_start()
    epoch_eval_metrics = self.eval_epoch(ts, epoch, algo)
    on_validation_epoch_end()
    eval_metrics.append(epoch_eval_metrics)
return ts, eval_metrics
```
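The pseudocode above uses Python `for` loops for readability. In a fully jitted implementation, the inner loop would typically be expressed with `jax.lax.scan` (or `jax.lax.fori_loop`) so that the whole epoch compiles to a single XLA computation. Here is a minimal sketch of that idea, assuming the hypothetical `MyAlgo` from the example above; the `train_epoch` helper and its signature are illustrative, not part of the actual API:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="num_steps")
def train_epoch(algo, ts, num_steps: int):
    # The Python `for step in range(...)` loop becomes a jax.lax.scan,
    # so the whole epoch traces to one compiled computation.
    def one_step(ts, step):
        batch = algo.get_batch(ts, step)
        ts, metrics = algo.training_step(step, ts, batch)
        return ts, metrics  # (carry, per-step output)

    return jax.lax.scan(one_step, ts, jnp.arange(num_steps))
```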
Caveats#
- Some lightning callbacks can be used with this trainer and work well, but not all of them.
- You can either use regular pytorch-lightning callbacks, or use jax.vmap on the fit method, but not both.
- If you want to use jax.vmap on the fit method, remove the callbacks from the Trainer for now (see the sketch after this list).
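As a rough sketch of the second option (an assumed usage pattern, not verified against the actual API), vmapping fit over a batch of PRNG keys trains one independent run per seed, provided the trainer has no callbacks:

```python
import jax

# Assumed setup: `trainer` is a JaxTrainer without callbacks and `algo` follows JaxModule.
rngs = jax.random.split(jax.random.PRNGKey(0), num=5)
# One independent training run per seed, executed in parallel.
train_states, eval_metrics = jax.vmap(lambda rng: trainer.fit(algo, rng))(rngs)
```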
TODOs / ideas#
- Add a checkpoint callback with orbax-checkpoint?
fit #
```python
fit(
    algo: JaxModule[Ts, _B, _MetricsT],
    rng: PRNGKey,
    train_state: Ts | None = None,
    skip_initial_evaluation: bool = False,
) -> tuple[Ts, _MetricsT]
```
Full training loop in pure jax (a lot faster than when using pytorch-lightning).
Unfolded version of rejax.PPO.train.
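A hypothetical single-run usage sketch; the constructor arguments are assumed from the attributes used in the training loop above (max_epochs, training_steps_per_epoch) and may not match the real signature:

```python
import jax

# Assumed constructor arguments; `algo` implements the JaxModule protocol.
trainer = JaxTrainer(max_epochs=10, training_steps_per_epoch=100)
# Returns the final training state and the evaluation metrics of each epoch.
ts, eval_metrics = trainer.fit(algo, rng=jax.random.PRNGKey(42))
```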
hparams_to_dict #
hparams_to_dict(algo: PyTreeNode) -> dict
Convert the learner struct to a serializable dict.
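Since the learner struct is a flax PyTreeNode (and therefore a frozen dataclass), the conversion can be pictured as a plain `dataclasses.asdict`. This is only an illustration of the idea, using a hypothetical `Hparams` struct, not the actual implementation:

```python
import dataclasses

from flax.struct import PyTreeNode


class Hparams(PyTreeNode):
    # Hypothetical hyper-parameters.
    learning_rate: float = 1e-3
    num_layers: int = 2


print(dataclasses.asdict(Hparams()))  # {'learning_rate': 0.001, 'num_layers': 2}
```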