Trainers
JaxTrainer
Bases: PyTreeNode
A simplified version of the lightning.Trainer with a fully jitted training loop.
Assumptions:
- The algo object must match the JaxModule protocol (in other words, it must implement the methods that the protocol declares).
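The full method list of JaxModule is not given on this page. The sketch below assumes, based only on the calls made in the training loop, that a compatible object provides init_train_state, get_batch and training_step. The names JaxModuleLike and CounterAlgo are hypothetical and only illustrate how a structural protocol is satisfied.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class JaxModuleLike(Protocol):
    """Hypothetical stand-in for JaxModule, inferred from the training-loop calls."""

    def init_train_state(self, rng: Any) -> Any: ...

    def get_batch(self, ts: Any, batch_idx: int) -> Any: ...

    def training_step(self, batch_idx: int, ts: Any, batch: Any) -> tuple[Any, dict]: ...


class CounterAlgo:
    """Toy algorithm: the train state is an int that grows by each batch value."""

    def init_train_state(self, rng: Any) -> int:
        return 0

    def get_batch(self, ts: int, batch_idx: int) -> int:
        return batch_idx + 1

    def training_step(self, batch_idx: int, ts: int, batch: int) -> tuple[int, dict]:
        new_ts = ts + batch
        return new_ts, {"ts": new_ts}


algo = CounterAlgo()
# runtime_checkable only checks that the methods exist, not their signatures.
print(isinstance(algo, JaxModuleLike))  # True
```

CounterAlgo never inherits from the protocol: with structural typing it only has to implement the methods.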
Training loop
This is the training loop, which is fully jitted:
```python
# Pseudocode: the on_*() calls mirror the lightning callback hooks of the same names.
ts = algo.init_train_state(rng)

setup("fit")
on_fit_start()
on_train_start()

eval_metrics = []
for epoch in range(self.max_epochs):
    on_train_epoch_start()

    for step in range(self.training_steps_per_epoch):
        batch = algo.get_batch(ts, step)

        on_train_batch_start()
        ts, metrics = algo.training_step(step, ts, batch)
        on_train_batch_end()

    on_train_epoch_end()

    # Evaluation "loop"
    on_validation_epoch_start()
    epoch_eval_metrics = self.eval_epoch(ts, epoch, algo)
    on_validation_epoch_end()

    eval_metrics.append(epoch_eval_metrics)

return ts, eval_metrics
```
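A Python for loop inside jax.jit is unrolled at trace time, so a fully jitted loop is usually written with jax.lax.scan. The sketch below is not JaxTrainer's implementation: it assumes a toy one-parameter least-squares problem (toy_get_batch, toy_training_step and fit_sketch are hypothetical names) and leaves out the on_*() hooks. It shows how the nested epoch/step loops of the training loop map onto two nested scans, with the per-epoch evaluation metrics stacked along a leading epoch axis instead of appended to a list.

```python
import functools

import jax
import jax.numpy as jnp

TRUE_W = 3.0  # target slope of the toy regression problem


def toy_get_batch(rng, step):
    # Deterministic per-step batch: fold the global step index into the key.
    key = jax.random.fold_in(rng, step)
    x = jax.random.normal(key, (32,))
    return x, TRUE_W * x


def toy_training_step(w, batch, lr=0.1):
    x, y = batch

    def loss_fn(w_):
        return jnp.mean((w_ * x - y) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(w)
    return w - lr * grad, loss


@functools.partial(jax.jit, static_argnames=("max_epochs", "steps_per_epoch"))
def fit_sketch(rng, max_epochs=5, steps_per_epoch=20):
    w = jnp.zeros(())  # plays the role of algo.init_train_state(rng)

    def train_epoch(w, epoch):
        def train_step(w, step):
            batch = toy_get_batch(rng, epoch * steps_per_epoch + step)
            return toy_training_step(w, batch)  # (new carry, per-step loss)

        w, _ = jax.lax.scan(train_step, w, jnp.arange(steps_per_epoch))
        # Evaluation "loop": here just the distance to the true slope.
        eval_metric = jnp.abs(w - TRUE_W)
        return w, eval_metric

    w, eval_metrics = jax.lax.scan(train_epoch, w, jnp.arange(max_epochs))
    return w, eval_metrics  # eval_metrics has shape (max_epochs,)


w, eval_metrics = fit_sketch(jax.random.PRNGKey(0))
```

Because max_epochs and steps_per_epoch are static, changing either one triggers a recompilation, while the loop bodies are traced only once per compilation.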
Caveats
- Some lightning callbacks can be used with this trainer and work well, but not all of them.
- You can either use regular pytorch-lightning callbacks or use jax.vmap on the fit method, but not both.
  - If you want to use jax.vmap on the fit method, remove the callbacks from the Trainer for now.
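One use of jax.vmap on a training function is to run several seeds in parallel within one compiled call. The sketch below uses a hypothetical, callback-free fit_one_seed (a toy noisy-regression loop, not JaxTrainer.fit) and vmaps it over a batch of PRNG keys created with jax.random.split.

```python
import jax
import jax.numpy as jnp


def fit_one_seed(rng, num_steps=50, lr=0.1):
    # Hypothetical callback-free "fit": SGD on a 1-D least-squares problem.
    # Both the initial weight and the sampled data depend on the seed.
    key_w, key_data = jax.random.split(rng)
    w = jax.random.normal(key_w, ())
    true_w = 2.0

    def step(w, i):
        x = jax.random.normal(jax.random.fold_in(key_data, i), (16,))
        grad = jax.grad(lambda w_: jnp.mean((w_ * x - true_w * x) ** 2))(w)
        return w - lr * grad, None

    w, _ = jax.lax.scan(step, w, jnp.arange(num_steps))
    return w


seeds = jax.random.split(jax.random.PRNGKey(0), 4)  # keys for 4 independent runs
final_w = jax.jit(jax.vmap(fit_one_seed))(seeds)
print(final_w.shape)  # (4,)
```

Each entry of final_w comes from a separate run; the function contains no Python-side callbacks, so vmap can batch the whole loop.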
TODOs / ideas
- Add a checkpoint callback with orbax-checkpoint?
fit
```python
fit(
    algo: JaxModule[Ts, _B, _MetricsT],
    rng: PRNGKey,
    train_state: Ts | None = None,
    skip_initial_evaluation: bool = False,
) -> tuple[Ts, _MetricsT]
```
Full training loop in pure jax (a lot faster than when using pytorch-lightning).
Unfolded version of rejax.PPO.train.
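If a flag such as skip_initial_evaluation is used in a plain Python if inside a jitted function, it has to be a static argument, because branching on a traced value raises an error. Whether fit handles the flag exactly this way is an assumption here; the sketch below uses hypothetical evaluate and fit_with_flag functions to show the pattern with functools.partial(jax.jit, static_argnames=...).

```python
import functools

import jax
import jax.numpy as jnp


def evaluate(params):
    # Hypothetical evaluation: mean squared distance of params to 1.
    return jnp.mean((params - 1.0) ** 2)


@functools.partial(jax.jit, static_argnames=("skip_initial_evaluation",))
def fit_with_flag(params, skip_initial_evaluation: bool = False):
    # A plain Python `if` works because the flag is static:
    # it is a concrete bool at trace time, not a tracer.
    if skip_initial_evaluation:
        initial_metric = jnp.asarray(jnp.nan)
    else:
        initial_metric = evaluate(params)
    params = params + 0.5 * (1.0 - params)  # one toy "training" update
    return params, initial_metric


params = jnp.zeros(3)
_, m0 = fit_with_flag(params)  # initial evaluation runs: m0 == 1.0
_, m1 = fit_with_flag(params, skip_initial_evaluation=True)  # m1 is nan
```

Each distinct value of a static argument gets its own compiled version of the function, so toggling the flag costs one extra compilation rather than a runtime branch.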