Jax trainer

Ts module-attribute

Ts = TypeVar('Ts', bound=PyTreeNode, default=PyTreeNode)

Type variable for the training state.

JaxModule

Bases: Protocol[Ts, _B, _MetricsT]

A protocol for algorithms that can be trained by the JaxTrainer.

JaxRLExample is one example of an algorithm that follows this protocol and can be trained with a JaxTrainer.

init_train_state

init_train_state(rng: PRNGKey) -> Ts

Create the initial training state.

get_batch

get_batch(ts: Ts, batch_idx: int) -> tuple[Ts, _B]

Produce a batch of data, returned together with the (possibly updated) training state.

training_step

training_step(
    batch_idx: int, ts: Ts, batch: _B
) -> tuple[Ts, PyTreeNode]

Update the training state using a "batch" of data, returning the new state together with a pytree of metrics.

eval_callback

eval_callback(ts: Ts) -> _MetricsT

Perform evaluation and return metrics.
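
A minimal sketch of an object that satisfies the protocol above, assuming a toy linear-regression task. ToyRegression, TrainState, the batch size and the learning rate are all hypothetical, and a NamedTuple stands in for a flax PyTreeNode training state (both are pytrees that JAX transformations accept):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TrainState(NamedTuple):
    """Hypothetical training state: one weight plus the RNG key for data sampling."""

    w: jax.Array
    rng: jax.Array


class ToyRegression:
    """Hypothetical algorithm following the JaxModule protocol: fits y = 3 * x with SGD."""

    batch_size: int = 32
    lr: float = 0.1

    def init_train_state(self, rng: jax.Array) -> TrainState:
        return TrainState(w=jnp.zeros(()), rng=rng)

    def get_batch(self, ts: TrainState, batch_idx: int):
        # Consume a fresh subkey and keep the new key in the returned state,
        # one reason get_batch may need to return the state with the batch.
        rng, data_rng = jax.random.split(ts.rng)
        x = jax.random.normal(data_rng, (self.batch_size,))
        return ts._replace(rng=rng), (x, 3.0 * x)

    def training_step(self, batch_idx: int, ts: TrainState, batch):
        x, y = batch

        def loss_fn(w):
            return jnp.mean((w * x - y) ** 2)

        loss, grad = jax.value_and_grad(loss_fn)(ts.w)
        return ts._replace(w=ts.w - self.lr * grad), {"loss": loss}

    def eval_callback(self, ts: TrainState) -> dict:
        return {"w_error": jnp.abs(ts.w - 3.0)}


# Driving the methods by hand, in the order the training loop calls them.
algo = ToyRegression()
ts = algo.init_train_state(jax.random.PRNGKey(0))
for step in range(100):
    ts, batch = algo.get_batch(ts, step)
    ts, metrics = algo.training_step(step, ts, batch)
print(algo.eval_callback(ts))
```

In the JaxTrainer the same calls happen inside a jitted loop rather than a Python for loop.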

JaxTrainer

Bases: PyTreeNode

A simplified version of the lightning.Trainer with a fully jitted training loop.

Assumptions:

  • The algo object must match the JaxModule protocol; in other words, it must implement init_train_state, get_batch, training_step and eval_callback.

Training loop

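This is the training loop, which is fully jitted (shown as pseudocode; setup and the on_* calls stand for the corresponding pytorch-lightning callback hooks):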

ts = algo.init_train_state(rng)

setup("fit")
on_fit_start()
on_train_start()

eval_metrics = []
for epoch in range(self.max_epochs):
    on_train_epoch_start()

    for step in range(self.training_steps_per_epoch):

        batch = algo.get_batch(ts, step)

        on_train_batch_start()

        ts, metrics = algo.training_step(step, ts, batch)

        on_train_batch_end()

    on_train_epoch_end()

    # Evaluation "loop"
    on_validation_epoch_start()
    epoch_eval_metrics = self.eval_epoch(ts, epoch, algo)
    on_validation_epoch_end()

    eval_metrics.append(epoch_eval_metrics)

return ts, eval_metrics
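
The loop above can be written without Python control flow so that jax.jit compiles it as a single program. A sketch, assuming the callback hooks are simply dropped and that each epoch's evaluation is a call to algo.eval_callback; CountingAlgo and fit_sketch are hypothetical, and nested jax.lax.scan calls replace the two for loops:

```python
import jax
import jax.numpy as jnp


class CountingAlgo:
    """Hypothetical stand-in for a JaxModule: the state accumulates batch indices."""

    def init_train_state(self, rng):
        return {"total": jnp.zeros((), jnp.int32), "rng": rng}

    def get_batch(self, ts, batch_idx):
        return ts, batch_idx  # the "batch" is just the step index

    def training_step(self, batch_idx, ts, batch):
        ts = {**ts, "total": ts["total"] + batch}
        return ts, {"total": ts["total"]}

    def eval_callback(self, ts):
        return {"total": ts["total"]}


def fit_sketch(algo, rng, max_epochs: int, steps_per_epoch: int):
    """Callback-free version of the training loop, using nested lax.scan calls."""
    ts = algo.init_train_state(rng)

    def train_step(ts, step):
        ts, batch = algo.get_batch(ts, step)
        return algo.training_step(step, ts, batch)  # (new carry, per-step metrics)

    def run_epoch(ts, _epoch):
        steps = jnp.arange(steps_per_epoch, dtype=jnp.int32)
        ts, _train_metrics = jax.lax.scan(train_step, ts, steps)
        return ts, algo.eval_callback(ts)  # the per-epoch "evaluation loop"

    # scan stacks the per-epoch eval metrics along a new leading axis.
    return jax.lax.scan(run_epoch, ts, jnp.arange(max_epochs))


# algo and the loop lengths are plain Python values, so they are marked static.
fit_jit = jax.jit(fit_sketch, static_argnums=(0, 2, 3))
ts, eval_metrics = fit_jit(CountingAlgo(), jax.random.PRNGKey(0), 3, 4)
print(eval_metrics["total"])  # one entry per epoch: 6, 12, 18
```

jax.lax.scan requires the carried training state to keep the same structure, shapes and dtypes from one step to the next, which is why the state is a pytree of fixed-shape arrays.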

Caveats

  • Some Lightning callbacks work well with this trainer, but not all of them.
  • You can either use regular pytorch-lightning callbacks or use jax.vmap on the fit method, but not both.
  • If you want to use jax.vmap on the fit method, remove the callbacks from the Trainer for now.
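
Without callbacks, a training run can be a pure function of its random key, and then jax.vmap can train several seeds inside one compiled program. A minimal sketch, assuming such a callback-free run; train_one_seed and its quadratic objective are hypothetical stand-ins for JaxTrainer.fit and a real algorithm:

```python
import jax
import jax.numpy as jnp


def train_one_seed(rng: jax.Array, num_steps: int = 50):
    """Hypothetical callback-free run: SGD pulls w toward a target drawn from rng."""
    target = 3.0 + jax.random.normal(rng, ())

    def step(w, _):
        grad = 2.0 * (w - target)  # derivative of (w - target) ** 2
        return w - 0.1 * grad, None

    w, _ = jax.lax.scan(step, jnp.zeros(()), None, length=num_steps)
    return w, target


# Eight independent runs, vectorized over the leading axis of the key array.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
final_w, targets = jax.jit(jax.vmap(train_one_seed))(keys)
```

Splitting one key with jax.random.split gives each run independent randomness while keeping the whole sweep reproducible from a single seed.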

TODOs / ideas

  • Add a checkpoint callback with orbax-checkpoint?

fit

fit(
    algo: JaxModule[Ts, _B, _MetricsT],
    rng: PRNGKey,
    train_state: Ts | None = None,
    skip_initial_evaluation: bool = False,
) -> tuple[Ts, _MetricsT]

Full training loop in pure JAX (a lot faster than running the same loop through pytorch-lightning).

Unfolded version of rejax.PPO.train.
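
The signature suggests two extra behaviours: passing train_state to resume from an existing state instead of calling algo.init_train_state(rng), and skip_initial_evaluation to avoid an evaluation before the first epoch. The following is a plain-Python guess at those semantics, not the JaxTrainer source; Counter, fit_with_resume_sketch, max_epochs and steps_per_epoch are hypothetical, and a real implementation would keep the loop inside jax.jit:

```python
import jax
import jax.numpy as jnp


class Counter:
    """Hypothetical JaxModule-like toy: each training step adds 1 to a counter."""

    def init_train_state(self, rng):
        return jnp.zeros((), jnp.int32)

    def get_batch(self, ts, batch_idx):
        return ts, None  # this toy needs no data

    def training_step(self, batch_idx, ts, batch):
        return ts + 1, {}

    def eval_callback(self, ts):
        return {"count": ts}


def fit_with_resume_sketch(algo, rng, train_state=None, skip_initial_evaluation=False,
                           max_epochs=2, steps_per_epoch=5):
    # Assumed semantics: reuse a given state, otherwise initialize a new one.
    ts = algo.init_train_state(rng) if train_state is None else train_state
    # Assumed semantics: record metrics once before any training happens.
    evals = [] if skip_initial_evaluation else [algo.eval_callback(ts)]
    for _epoch in range(max_epochs):
        for step in range(steps_per_epoch):
            ts, batch = algo.get_batch(ts, step)
            ts, _metrics = algo.training_step(step, ts, batch)
        evals.append(algo.eval_callback(ts))
    # Stack the per-evaluation metric dicts into one dict of arrays.
    return ts, jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *evals)


algo = Counter()
ts1, evals1 = fit_with_resume_sketch(algo, jax.random.PRNGKey(0))
# evals1["count"] is [0, 5, 10]: the initial evaluation, then one per epoch.
ts2, evals2 = fit_with_resume_sketch(algo, jax.random.PRNGKey(0), train_state=ts1,
                                     skip_initial_evaluation=True)
# Resumed from a count of 10: evals2["count"] is [15, 20].
```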

training_step

training_step(
    batch_idx: int,
    ts: Ts,
    algo: JaxModule[Ts, _B, _MetricsT],
)

Training step in pure JAX, combining data collection and training in one step.

Much faster than going through pytorch-lightning, but you lose the callbacks and similar Lightning features.
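
Because algo carries Python methods rather than arrays, one way to jit a step that takes it as an argument is to mark it static; data collection (get_batch) and the update (training_step) are then traced into a single compiled function. A minimal sketch under that assumption; RandomWalk and trainer_training_step are hypothetical names, not the JaxTrainer code:

```python
import functools

import jax
import jax.numpy as jnp


class RandomWalk:
    """Hypothetical algorithm: 'data collection' samples small noise and
    'training' shrinks the parameters toward zero."""

    def init_train_state(self, rng):
        return {"params": jnp.ones(3), "rng": rng}

    def get_batch(self, ts, batch_idx):
        rng, sample_rng = jax.random.split(ts["rng"])
        batch = 0.01 * jax.random.normal(sample_rng, (3,))
        return {**ts, "rng": rng}, batch

    def training_step(self, batch_idx, ts, batch):
        params = 0.5 * ts["params"] + batch
        return {**ts, "params": params}, {"norm": jnp.linalg.norm(params)}


@functools.partial(jax.jit, static_argnames="algo")
def trainer_training_step(batch_idx, ts, algo):
    # Data collection and the update are traced together, so XLA compiles
    # them into one program with no return to Python in between.
    ts, batch = algo.get_batch(ts, batch_idx)
    return algo.training_step(batch_idx, ts, batch)


algo = RandomWalk()
ts = algo.init_train_state(jax.random.PRNGKey(0))
for i in range(20):
    ts, metrics = trainer_training_step(i, ts, algo)
```

A static argument must be hashable; a plain class instance hashes by identity, so reusing the same algo object reuses the compiled step.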

hparams_to_dict

hparams_to_dict(algo: PyTreeNode) -> dict

Convert the learner struct to a serializable dict.
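
Subclasses of flax.struct.PyTreeNode are frozen dataclasses, so one plausible way to serialize a learner struct is to walk its fields with the standard dataclasses module, converting arrays with .tolist() and falling back to repr for anything else. This is an assumed approach, not the library's implementation; hparams_to_dict_sketch, OptimizerConfig and PPOConfig are hypothetical, and plain frozen dataclasses stand in for PyTreeNode structs so the example does not need Flax:

```python
import dataclasses
import json
from typing import Any

import jax
import jax.numpy as jnp


def hparams_to_dict_sketch(obj: Any) -> Any:
    """Recursively convert a (hypothetical) learner struct into JSON-friendly values."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # One entry per dataclass field, converted recursively.
        return {f.name: hparams_to_dict_sketch(getattr(obj, f.name))
                for f in dataclasses.fields(obj)}
    if isinstance(obj, dict):
        return {str(k): hparams_to_dict_sketch(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [hparams_to_dict_sketch(v) for v in obj]
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if hasattr(obj, "tolist"):
        # NumPy and JAX arrays (and NumPy scalars) become Python lists or scalars.
        return obj.tolist()
    return repr(obj)  # last resort, e.g. for functions or modules


@dataclasses.dataclass(frozen=True)
class OptimizerConfig:
    learning_rate: float = 3e-4
    betas: tuple = (0.9, 0.999)


@dataclasses.dataclass(frozen=True)
class PPOConfig:
    num_envs: int = 8
    gamma: float = 0.99
    env_id: str = "CartPole-v1"
    optimizer: OptimizerConfig = dataclasses.field(default_factory=OptimizerConfig)
    obs_scale: jax.Array = dataclasses.field(default_factory=lambda: jnp.ones(2))


hparams = hparams_to_dict_sketch(PPOConfig())
print(json.dumps(hparams))
```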