Using Jax with PyTorch-Lightning

🔥 NOTE: This is a feature that is entirely unique to this template! 🔥

This template includes examples that use either Jax, PyTorch, or both!

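| Example link              | Reference                 | Framework   | Lightning?   |
|---------------------------|---------------------------|-------------|--------------|
| ExampleAlgorithm          | ExampleAlgorithm          | Torch       | yes          |
| JaxExample                | JaxExample                | Torch + Jax | yes          |
| TextClassificationExample | TextClassificationExample | Torch + 🤗  | yes          |
| JaxRLExample              | JaxRLExample              | Jax         | no (almost!) |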

In fact, here you can mix and match both Jax and Torch code. For example, you can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning.

How does this work?

Well, we use torch-jax-interop, another package developed here at Mila 😎, that allows easy interop between torch and jax code. Feel free to take a look at it if you'd like to use it as part of your own project. 😁

Using PyTorch-Lightning to train a Jax network

If you'd like to use Jax in your network or learning algorithm, while keeping the same style of training loop as usual, you can!

  • Use Jax for the forward / backward passes, the parameter updates, dataset preprocessing, etc.
  • Leave the training loop / callbacks / logging / checkpointing / etc. to Lightning

The lightning.Trainer will not be able to tell that you're using Jax!
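
As a rough sketch of the Jax side of this split (forward pass, backward pass, and parameter update), the snippet below trains a toy linear classifier with plain SGD. Everything in it (`init_params`, `loss_fn`, `sgd_step`, the shapes, and the learning rate) is an illustrative assumption, not code from the template. It also leaves out the Lightning wiring and any torch-jax-interop conversions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value, chosen only for this toy example


def init_params(key, in_dim: int, out_dim: int):
    """Create the parameters of a single linear layer (hypothetical toy model)."""
    return {
        "w": 0.01 * jax.random.normal(key, (in_dim, out_dim)),
        "b": jnp.zeros((out_dim,)),
    }


def loss_fn(params, x, y):
    """Forward pass followed by the mean cross-entropy loss over the batch."""
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    # Pick the log-probability of the correct class for each sample.
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))


@jax.jit
def sgd_step(params, x, y):
    """Backward pass and SGD parameter update, compiled with jax.jit."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Toy data: 8 samples with 4 features each, and 3 classes.
params = init_params(jax.random.PRNGKey(0), in_dim=4, out_dim=3)
x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
y = jnp.arange(8) % 3

params, first_loss = sgd_step(params, x, y)
for _ in range(20):
    params, loss = sgd_step(params, x, y)
```

In a Lightning-based setup, a function like `sgd_step` is the kind of code that would be called from the module's training step, while the Trainer keeps handling the loop, logging, and checkpointing.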

Take a look at this image classification example that uses a Jax network.

End-to-end training in Jax: the JaxTrainer

The JaxTrainer, used in the Jax RL Example, follows a structure similar to that of the lightning.Trainer. However, instead of training LightningModules, it trains JaxModules, which are a simplified, jax-based look-alike of lightning.LightningModules.
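
To give a feel for this trainer/module split, here is a minimal pure-Python sketch. `ToyModule`, `RunningMean`, and `toy_fit` are hypothetical names invented for illustration: they are not the template's JaxModule or JaxTrainer, and the `training_step` signature shown is an assumption.

```python
from typing import Any, Iterable, Protocol


class ToyModule(Protocol):
    """Hypothetical module interface: anything with a `training_step` method."""

    def training_step(self, batch: Any) -> float: ...


class RunningMean:
    """Toy 'algorithm': tracks the mean of all values seen so far."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def training_step(self, batch: list[float]) -> float:
        # "Train" on the batch by updating the running statistics.
        self.total += sum(batch)
        self.count += len(batch)
        return self.total / self.count


def toy_fit(module: ToyModule, batches: Iterable[Any]) -> list[float]:
    """Hypothetical trainer loop: call `training_step` once per batch."""
    return [module.training_step(batch) for batch in batches]


history = toy_fit(RunningMean(), [[1.0, 3.0], [5.0, 7.0]])
```

Because `ToyModule` is a structural protocol, `RunningMean` does not need to inherit from anything: exposing a compatible `training_step` method is enough for the loop to drive it.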

The "algorithm" needs to match the JaxModule protocol:

  • JaxModule.training_step: train using a batch of data