Trainers

JaxTrainer #

Bases: PyTreeNode

A simplified version of the lightning.Trainer with a fully jitted training loop.

Assumptions:#

  • The algo object must match the JaxModule protocol (in other words, it must implement the methods that JaxModule declares — see the sketch below).
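
A hedged sketch (not the actual JaxModule definition) of the methods the training loop below relies on; names and argument order are taken from that loop:

class MyAlgo:
    def init_train_state(self, rng):
        """Create the initial train state Ts from an RNG key."""
        ...

    def get_batch(self, ts, batch_idx):
        """Return the batch of data to use for this training step."""
        ...

    def training_step(self, batch_idx, ts, batch):
        """Update the train state on one batch and return (ts, metrics)."""
        ...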

Training loop#

This is the training loop, which is fully jitted:

ts = algo.init_train_state(rng)

setup("fit")
on_fit_start()
on_train_start()

eval_metrics = []
for epoch in range(self.max_epochs):
    on_train_epoch_start()

    for step in range(self.training_steps_per_epoch):

        batch = algo.get_batch(ts, step)

        on_train_batch_start()

        ts, metrics = algo.training_step(step, ts, batch)

        on_train_batch_end()

    on_train_epoch_end()

    # Evaluation "loop"
    on_validation_epoch_start()
    epoch_eval_metrics = self.eval_epoch(ts, epoch, algo)
    on_validation_epoch_end()

    eval_metrics.append(epoch_eval_metrics)

return ts, eval_metrics
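
The Python for-loops above are illustrative; inside a jitted function, a loop like this is typically expressed with jax.lax.scan (or jax.lax.fori_loop) so it can be traced once and compiled. A minimal sketch of the same epoch structure, with hypothetical train_epoch and eval_epoch stand-ins:

import functools
import jax
import jax.numpy as jnp

def train_epoch(ts, epoch):
    # Hypothetical stand-in for the inner training loop above.
    return ts

def eval_epoch(ts, epoch):
    # Hypothetical stand-in for the evaluation "loop" above.
    return {"accuracy": jnp.zeros(())}

@functools.partial(jax.jit, static_argnames="max_epochs")
def fit(ts, max_epochs: int):
    def one_epoch(ts, epoch):
        ts = train_epoch(ts, epoch)
        return ts, eval_epoch(ts, epoch)

    # jax.lax.scan traces `one_epoch` once and stacks the per-epoch eval
    # metrics, keeping the whole epoch loop inside one compiled computation.
    return jax.lax.scan(one_epoch, ts, jnp.arange(max_epochs))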

Caveats#

  • Some lightning callbacks can be used with this trainer and work well, but not all of them.
  • You can either use regular pytorch-lightning callbacks, or use jax.vmap on the fit method, but not both.
  • If you want to use jax.vmap on the fit method, remove the callbacks from the Trainer for now (see the sketch after this list).
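
A hedged usage sketch of vmapping fit over several seeds. The constructor arguments and MyAlgo are assumptions for illustration; the important point is that no callbacks are passed to the Trainer:

import functools
import jax

# Hypothetical usage: train 8 independent runs in parallel by vmapping
# `trainer.fit` over the RNG key. Constructor arguments are assumed.
algo = MyAlgo()
trainer = JaxTrainer(max_epochs=10, training_steps_per_epoch=100)

rngs = jax.random.split(jax.random.PRNGKey(0), 8)
all_ts, all_eval_metrics = jax.vmap(functools.partial(trainer.fit, algo))(rngs)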

TODOs / ideas#

  • Add a checkpoint callback with orbax-checkpoint?
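
One possible starting point for that idea, assuming orbax-checkpoint's PyTreeCheckpointer; the callback class and hook below are hypothetical and not part of this trainer:

import orbax.checkpoint as ocp

class OrbaxCheckpointCallback:
    """Hypothetical callback that saves the train state at the end of each epoch."""

    def __init__(self, directory: str):
        self.directory = directory
        self.checkpointer = ocp.PyTreeCheckpointer()

    def on_train_epoch_end(self, trainer, ts, epoch: int):
        # Save the (pytree) train state for this epoch.
        self.checkpointer.save(f"{self.directory}/epoch_{epoch}", ts)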

fit #

fit(
    algo: JaxModule[Ts, _B, _MetricsT],
    rng: PRNGKey,
    train_state: Ts | None = None,
    skip_initial_evaluation: bool = False,
) -> tuple[Ts, _MetricsT]

Full training loop in pure jax (a lot faster than going through pytorch-lightning).

Unfolded version of rejax.PPO.train.
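
A hedged usage sketch; the constructor arguments are assumptions based on the attributes used in the training loop above, and MyAlgo stands in for anything matching the JaxModule protocol:

import jax

algo = MyAlgo()
trainer = JaxTrainer(max_epochs=10, training_steps_per_epoch=100)  # hypothetical arguments

rng = jax.random.PRNGKey(42)
ts, eval_metrics = trainer.fit(algo, rng)

# Resume from an existing train state and skip the initial evaluation:
ts, eval_metrics = trainer.fit(algo, rng, train_state=ts, skip_initial_evaluation=True)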

training_step #

training_step(
    batch_idx: int,
    ts: Ts,
    algo: JaxModule[Ts, _B, _MetricsT],
)

Training step in pure jax (joined data collection + training).

MUCH faster than using pytorch-lightning, but you lose the callbacks and such.
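
Roughly, and following the pseudocode above, this amounts to something like the following (a sketch, not the actual implementation):

def training_step(self, batch_idx, ts, algo):
    # Joined data collection + training: fetch the batch and update the
    # train state in one jit-friendly step, with no callback dispatch in between.
    batch = algo.get_batch(ts, batch_idx)
    ts, metrics = algo.training_step(batch_idx, ts, batch)
    return ts, metrics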