Using Jax with PyTorch-Lightning#

You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from everything that comes with using PyTorch-Lightning (logging, checkpointing, distributed training, and so on).

How does this work? Well, we use torch-jax-interop, another package developed here at Mila, which allows easy interop between torch and jax code. See the readme on that repo for more details.
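
For instance, converting tensors between the two frameworks is a one-liner in either direction. Here is a minimal sketch using the `torch_to_jax` function that also appears in the example below, and its counterpart `jax_to_torch` from the same package:

import jax.numpy as jnp
import torch
from torch_jax_interop import jax_to_torch, torch_to_jax

torch_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Convert the torch Tensor into a jax Array ...
jax_input = torch_to_jax(torch_input)

# ... do some work in jax ...
jax_output = jnp.tanh(jax_input)

# ... and bring the result back as a torch Tensor.
torch_output = jax_to_torch(jax_output)
assert isinstance(torch_output, torch.Tensor)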

You can use Jax in your network or learning algorithm, for example in your forward / backward passes, to update parameters, etc., but not in the training loop itself, since that is handled by the lightning.Trainer. There are lots of good reasons why you might want to let Lightning handle the training loop, which are very well described here.

What about end-to-end training in Jax?

This template doesn't include a way to do end-to-end, fully-jitted training in Jax. However, it might be possible to do so as follows:

  • add a new configuration in the trainer config group, with a _target_ pointing to a trainer-like object with fit, evaluate and test methods mimicking those of PyTorch-Lightning.
  • add a new configuration in the algorithm config group pointing to a learning algorithm class that isn't a LightningModule.

If you want an example of how to do this, please make an issue (or like an existing issue) on GitHub.
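
To give a very rough idea of the first bullet, such a trainer-like object could expose an interface along these lines. This is a hypothetical sketch, not something included in the template; the class name, fields and signatures below are made up:

import dataclasses
from typing import Any


@dataclasses.dataclass
class JaxTrainer:
    """Hypothetical stand-in for lightning.Trainer for fully-jitted training in Jax."""

    max_epochs: int = 10

    def fit(self, algorithm: Any, datamodule: Any) -> None:
        # Would run the (jitted) jax training loop of `algorithm` on the datamodule's train split.
        raise NotImplementedError

    def evaluate(self, algorithm: Any, datamodule: Any) -> dict[str, float]:
        # Would compute and return the validation metrics.
        raise NotImplementedError

    def test(self, algorithm: Any, datamodule: Any) -> dict[str, float]:
        # Would compute and return the test metrics.
        raise NotImplementedError

A trainer config with a _target_ pointing to such a class could then be swapped in for the lightning.Trainer config.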

JaxExample: a LightningModule that uses Jax#

The JaxExample algorithm uses a network which is a flax.linen.Module. The network is wrapped with torch_jax_interop.WrappedJaxFunction, so that it accepts torch tensors as inputs, produces torch tensors as outputs, and stores its parameters as torch.nn.Parameters (which use the same underlying memory as the jax arrays). In this example, the loss function and optimizer are in PyTorch, while the network forward and backward passes are written in Jax.

The loss that is returned in the training step is used by Lightning in the usual way. The backward pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer.

Note

You could also very well do both the forward and backward passes in Jax! To do this, use the 'manual optimization' mode of PyTorch-Lightning and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly.
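
Until then, here is a rough, hypothetical sketch of what that could look like (not the template's implementation). It assumes CIFAR10-shaped inputs, uses optax for the updates, and assumes that torch_jax_interop also provides a jax_to_torch function (the inverse of the torch_to_jax used below); the class and attribute names are made up:

import flax.linen
import jax
import optax
import torch
from lightning import LightningModule
from torch_jax_interop import jax_to_torch, torch_to_jax


class ManualJaxExample(LightningModule):
    """Sketch of an algorithm where the forward pass, backward pass and updates are all in Jax."""

    def __init__(self, network: flax.linen.Module, lr: float = 1e-3, seed: int = 123):
        super().__init__()
        # Lightning will not call loss.backward() or optimizer.step() for us.
        self.automatic_optimization = False
        self.network = network
        self.optimizer = optax.sgd(lr)

        # Initialize the jax parameters and optimizer state (assuming CIFAR10-shaped inputs).
        example_input = jax.numpy.zeros((1, 3, 32, 32))
        self.jax_params = network.init(jax.random.key(seed), example_input)
        self.opt_state = self.optimizer.init(self.jax_params)

        # Store a view of the parameters as torch.nn.Parameters so that checkpointing,
        # self.parameters(), etc. keep working with the rest of Lightning.
        self.torch_params = torch.nn.ParameterList(
            torch.nn.Parameter(jax_to_torch(p), requires_grad=False)
            for p in jax.tree_util.tree_leaves(self.jax_params)
        )

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        x, y = torch_to_jax(batch[0]), torch_to_jax(batch[1])

        def loss_fn(params):
            logits = self.network.apply(params, x)
            return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

        # Forward and backward passes, entirely in jax.
        loss, grads = jax.value_and_grad(loss_fn)(self.jax_params)

        # Manual parameter update with optax.
        updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
        self.jax_params = optax.apply_updates(self.jax_params, updates)

        # Refresh the torch view of the parameters after the update.
        for torch_param, jax_param in zip(
            self.torch_params, jax.tree_util.tree_leaves(self.jax_params)
        ):
            torch_param.data = jax_to_torch(jax_param)

        self.log("train/loss", jax_to_torch(loss), prog_bar=True)

    def configure_optimizers(self):
        # No torch optimizer: the updates are done manually in `training_step` above.
        return None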

Jax Network#

import flax.linen
import jax

# Note: `to_channels_last` and `flatten` are small helper functions defined in the project
# alongside this module (they put the input in channels-last (NHWC) format and flatten all
# but the batch dimension, respectively).


class CNN(flax.linen.Module):
    """A simple CNN model.

    Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network
    """

    num_classes: int = 10

    @flax.linen.compact
    def __call__(self, x: jax.Array):
        x = to_channels_last(x)
        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = flatten(x)
        x = flax.linen.Dense(features=256)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=self.num_classes)(x)
        return x
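
As a quick sanity check, the network can be initialized and applied like any other flax module. This small sketch assumes the CNN class above (with its to_channels_last and flatten helpers) and CIFAR10-shaped inputs in the channels-first layout used by the datamodules:

import jax
import jax.numpy as jnp

network = CNN(num_classes=10)
# Initialize the parameters with a dummy channels-first (NCHW) input, like a CIFAR10 batch.
dummy_input = jnp.zeros((1, 3, 32, 32))
params = network.init(jax.random.key(0), dummy_input)
logits = network.apply(params, dummy_input)
assert logits.shape == (1, 10)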

Jax Algorithm#

import dataclasses
from typing import Literal

import flax.linen
import jax
import torch
from lightning import Callback, LightningModule
from torch_jax_interop import WrappedJaxFunction, torch_to_jax

# Note: the datamodule and callback classes used below (ImageClassificationDataModule,
# ClassificationDataModule, ClassificationMetricsCallback, MeasureSamplesPerSecondCallback)
# are helpers defined elsewhere in this template.


class JaxExample(LightningModule):
    """Example of a learning algorithm (`LightningModule`) that uses Jax.

    In this case, the network is a flax.linen.Module whose forward and backward passes are
    written in Jax, while the loss function is in PyTorch.
    """

    @dataclasses.dataclass(frozen=True)
    class HParams:
        """Hyper-parameters of the algo."""

        lr: float = 1e-3
        seed: int = 123
        debug: bool = True

    def __init__(
        self,
        *,
        network: flax.linen.Module,
        datamodule: ImageClassificationDataModule,
        hp: HParams = HParams(),
    ):
        super().__init__()
        self.datamodule = datamodule
        self.hp = hp or self.HParams()

        example_input = torch.zeros(
            (datamodule.batch_size, *datamodule.dims),
            device=self.device,
        )
        # Initialize the jax parameters with a forward pass.
        params = network.init(jax.random.key(self.hp.seed), x=torch_to_jax(example_input))

        # Wrap the jax network into a nn.Module:
        self.network = WrappedJaxFunction(
            jax_function=jax.jit(network.apply) if not self.hp.debug else network.apply,
            jax_params=params,
            # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
            # Invalid device pointer when trying to share the CUDA tensors that come from jax.
            clone_params=True,
            has_aux=False,
        )

        self.example_input_array = example_input

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        logits = self.network(input)
        return logits

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="train")

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="val")

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="test")

    def shared_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        batch_index: int,
        phase: Literal["train", "val", "test"],
    ):
        x, y = batch
        assert not x.requires_grad
        logits = self.network(x)
        assert isinstance(logits, torch.Tensor)
        # In this example we use a jax "encoder" network and a PyTorch loss function, but we could
        # also just as easily have done the whole forward and backward pass in jax if we wanted to.
        loss = torch.nn.functional.cross_entropy(logits, target=y, reduction="mean")
        acc = logits.argmax(-1).eq(y).float().mean()
        self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True)
        self.log(f"{phase}/acc", acc, prog_bar=True, sync_dist=True)
        return {"loss": loss, "logits": logits, "y": y}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hp.lr)

    def configure_callbacks(self) -> list[Callback]:
        assert isinstance(self.datamodule, ClassificationDataModule)
        return [
            MeasureSamplesPerSecondCallback(),
            ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes),
        ]

Configs#

JaxExample algorithm config#

# Config for the JaxExample algorithm
defaults:
  - network: jax_cnn

_target_: project.algorithms.jax_example.JaxExample
# NOTE: Why _partial_ here? Because the config doesn't create the algo directly.
# The datamodule is instantiated first and then passed to the algorithm.
_partial_: true
hp:
  lr: 0.001
  seed: 123
  debug: False

Running the example#

$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10