Jax trainer
A simplified version of the lightning.Trainer with a fully jitted training loop.
This is used by the JaxRLExample algorithm (PPO) in the jax_ppo.py module.
Ts module-attribute #
Ts = TypeVar('Ts', bound=PyTreeNode, default=PyTreeNode)
Type Variable for the training state.
JaxModule #
Bases: Protocol[Ts, _B, _MetricsT]
A protocol for algorithms that can be trained by the JaxTrainer.
The JaxRLExample is an example that follows this structure and can be trained with a JaxTrainer.
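As a rough illustration of what "following this structure" could mean, the sketch below declares a structural protocol with the three methods that appear in the training loop documented for JaxTrainer (init_train_state, get_batch, training_step). The names TrainableSketch and CountingAlgo are hypothetical, the argument and return types are assumptions, and the real JaxModule protocol may declare additional methods.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class TrainableSketch(Protocol):
    """Hypothetical stand-in for JaxModule; method names come from the training loop."""

    def init_train_state(self, rng: Any) -> Any: ...

    def get_batch(self, ts: Any, step: int) -> Any: ...

    def training_step(self, step: int, ts: Any, batch: Any) -> tuple[Any, Any]: ...


class CountingAlgo:
    """Toy algorithm whose training state is just a step counter."""

    def init_train_state(self, rng):
        # A real algorithm would build parameters and optimizer state from rng.
        return 0

    def get_batch(self, ts, step):
        # A real algorithm would sample or collect data here.
        return step

    def training_step(self, step, ts, batch):
        # Return the updated training state and a dict of metrics.
        return ts + 1, {"batch": batch}


algo = CountingAlgo()
# Structural check: CountingAlgo never inherits from TrainableSketch.
is_trainable = isinstance(algo, TrainableSketch)
```

Because the check is structural, any object that provides these methods satisfies the protocol, with no common base class required.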
JaxTrainer #
Bases: PyTreeNode
A simplified version of the lightning.Trainer with a fully jitted training loop.
Assumptions:#
- The algo object must match the JaxModule protocol (in other words, it should implement its methods).
Training loop#
This is the training loop, which is fully jitted:
ts = algo.init_train_state(rng)
setup("fit")
on_fit_start()
on_train_start()
eval_metrics = []
for epoch in range(self.max_epochs):
    on_train_epoch_start()
    for step in range(self.training_steps_per_epoch):
        batch = algo.get_batch(ts, step)
        on_train_batch_start()
        ts, metrics = algo.training_step(step, ts, batch)
        on_train_batch_end()
    on_train_epoch_end()
    # Evaluation "loop"
    on_validation_epoch_start()
    epoch_eval_metrics = self.eval_epoch(ts, epoch, algo)
    on_validation_epoch_end()
    eval_metrics.append(epoch_eval_metrics)
return ts, eval_metrics
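One way a loop of this shape can be compiled as a single jitted function is to express both Python for-loops as nested jax.lax.scan calls. The sketch below does that for a toy problem. It is not the JaxTrainer implementation: the callback hooks are omitted, and the free functions, constants, and the toy update rule are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

MAX_EPOCHS = 3  # assumed value, plays the role of self.max_epochs
STEPS_PER_EPOCH = 4  # assumed value, plays the role of self.training_steps_per_epoch


def init_train_state(rng):
    # Toy training state: a single scalar parameter.
    return jax.random.normal(rng, ())


def get_batch(ts, step):
    # Toy batch derived from the step index.
    return jnp.asarray(step, dtype=jnp.float32)


def training_step(step, ts, batch):
    # Toy update that shrinks the parameter towards zero; the batch is unused.
    return ts - 0.1 * ts, {"loss": ts**2}


def eval_epoch(ts, epoch):
    return {"eval_loss": ts**2}


@jax.jit
def fit(rng):
    ts = init_train_state(rng)

    def epoch_body(ts, epoch):
        def step_body(ts, step):
            batch = get_batch(ts, step)
            return training_step(step, ts, batch)

        # Inner loop over training steps; per-step metrics are stacked by scan.
        ts, _train_metrics = jax.lax.scan(step_body, ts, jnp.arange(STEPS_PER_EPOCH))
        # Evaluation "loop": one set of metrics per epoch.
        return ts, eval_epoch(ts, epoch)

    # Outer loop over epochs; scan stacks the per-epoch evaluation metrics.
    return jax.lax.scan(epoch_body, ts, jnp.arange(MAX_EPOCHS))


final_ts, eval_metrics = fit(jax.random.PRNGKey(0))
```

Unlike the list built with eval_metrics.append in the loop above, scan returns the evaluation metrics as a single pytree whose leaves have a leading axis of length MAX_EPOCHS.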
Caveats#
- Some lightning callbacks can be used with this trainer and work well, but not all of them.
- You can either use regular pytorch-lightning callbacks, or use jax.vmap on the fit method, but not both.
  - If you want to use jax.vmap on the fit method, just remove the callbacks on the Trainer for now.
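To illustrate the jax.vmap option, the sketch below vectorizes a callback-free fit function over several random seeds. The fit function here is a hypothetical stand-in with a toy update rule, not the JaxTrainer.fit method, and its single-argument signature is an assumption.

```python
import jax
import jax.numpy as jnp


def fit(rng):
    # Hypothetical stand-in for a jittable fit function with no callbacks.
    ts = jax.random.normal(rng, ())

    def step(ts, _):
        ts = 0.9 * ts  # toy update rule
        return ts, ts**2  # new state and a per-step "loss"

    return jax.lax.scan(step, ts, None, length=5)


# One independent training run per seed, executed as a single batched computation.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
final_states, losses = jax.jit(jax.vmap(fit))(rngs)
```

Each output gains a leading axis indexed by seed, so the per-seed results can be compared after a single call.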
TODOs / ideas#
- Add a checkpoint callback with orbax-checkpoint?
is_global_zero property #
is_global_zero: bool
Check if the current process is the global zero process in a distributed setup.
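A check of this kind can be written with jax.process_index(), which returns the index of the current process (0 when running a single process). The sketch below is an assumption about how such a property could be computed, not a description of how JaxTrainer does it, and the function name is hypothetical.

```python
import jax


def is_global_zero_sketch() -> bool:
    # jax.process_index() is 0 on the first (or only) process.
    return jax.process_index() == 0


only_on_main = is_global_zero_sketch()
```

A common use of such a flag is to restrict logging or checkpoint writing to a single process.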
get_error_from_metrics #
Return the 'error' to minimize for hyperparameter optimization from a set of metrics.
hparams_to_dict #
hparams_to_dict(algo: PyTreeNode) -> dict
Convert the learner struct to a serializable dict.
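As a minimal sketch of the idea, assuming the learner is a dataclass-like struct, dataclasses.asdict recursively turns nested fields into plain dicts that can then be serialized. The names ToyHParams, ToyAlgo, and hparams_to_dict_sketch are hypothetical. The real function may handle more cases, for example array-valued fields, which this sketch does not convert.

```python
import dataclasses
import json


@dataclasses.dataclass(frozen=True)
class ToyHParams:
    lr: float = 3e-4
    num_envs: int = 8


@dataclasses.dataclass(frozen=True)
class ToyAlgo:
    hp: ToyHParams = dataclasses.field(default_factory=ToyHParams)
    name: str = "toy"


def hparams_to_dict_sketch(algo) -> dict:
    # Recursively converts nested dataclasses into plain dicts.
    return dataclasses.asdict(algo)


as_dict = hparams_to_dict_sketch(ToyAlgo())
encoded = json.dumps(as_dict, sort_keys=True)  # plain dicts can be logged or saved
```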