Jax + PyTorch-Lightning ⚡#

A LightningModule that trains a Jax network#

The JaxImageClassifier algorithm uses a network that is a flax.linen.Module. The network is wrapped with torch_jax_interop.WrappedJaxFunction so that it accepts torch tensors as inputs, produces torch tensors as outputs, and stores its parameters as torch.nn.Parameters (which use the same underlying memory as the jax arrays, unless they are cloned as with clone_params=True in the code below). In this example, the loss function and optimizer are in PyTorch, while the network's forward and backward passes are written in Jax.

The loss that is returned in the training step is used by Lightning in the usual way. The backward pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer.
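The mechanism can be illustrated without the wrapper library. The following is a minimal sketch, not the actual `torch_jax_interop` implementation: `jax_forward`, `to_jax`, `to_torch` and `JaxBridge` are hypothetical names, the network is a single dense layer, and tensors are copied through NumPy on the CPU for simplicity rather than sharing memory.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn.functional as F


def jax_forward(params, x):
    """Hypothetical stand-in for `flax_module.apply`: one dense layer."""
    return jnp.tanh(x @ params["w"] + params["b"])


def to_jax(t: torch.Tensor) -> jax.Array:
    # Copies through NumPy (CPU only); a real bridge would avoid the copy.
    return jnp.array(t.detach().cpu().numpy())


def to_torch(a: jax.Array) -> torch.Tensor:
    return torch.from_numpy(np.array(a))


class JaxBridge(torch.autograd.Function):
    """Forward pass in Jax; backward pass through the Jax VJP."""

    @staticmethod
    def forward(ctx, w, b, x):
        params = {"w": to_jax(w), "b": to_jax(b)}
        out, vjp_fn = jax.vjp(jax_forward, params, to_jax(x))
        ctx.vjp_fn = vjp_fn  # Jax closure reused in the backward pass
        return to_torch(out)

    @staticmethod
    def backward(ctx, grad_out):
        param_grads, x_grad = ctx.vjp_fn(to_jax(grad_out))
        return to_torch(param_grads["w"]), to_torch(param_grads["b"]), to_torch(x_grad)


torch.manual_seed(0)
w = torch.nn.Parameter(torch.randn(4, 3))
b = torch.nn.Parameter(torch.zeros(3))
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

logits = JaxBridge.apply(w, b, x)       # forward pass runs in Jax
loss = F.cross_entropy(logits, y)       # loss is computed in PyTorch
loss.backward()                         # gradients come from the Jax VJP
torch.optim.SGD([w, b], lr=0.1).step()  # update is done by a PyTorch optimizer
```

From PyTorch's point of view, `JaxBridge` is an ordinary differentiable operation, which is why the returned loss can be handled by Lightning without any special treatment.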

Info

You could also very well do both the forward and backward passes in Jax! To do this, use the 'manual optimization' mode of PyTorch-Lightning and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly.
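Until that example is available, here is a rough sketch of the idea outside of Lightning. It assumes a linear softmax classifier and hypothetical names (`loss_fn`, `torch_params`, `manual_step`): the loss and gradients are computed with `jax.value_and_grad`, the gradients are written into the `.grad` field of each `torch.nn.Parameter`, and a PyTorch optimizer applies the update. In a LightningModule with `self.automatic_optimization = False`, the body of `training_step` could follow the same pattern, using the optimizer returned by `self.optimizers()`.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def loss_fn(params, x, y):
    """Cross-entropy of a linear classifier, written entirely in Jax."""
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))


torch.manual_seed(0)
# Parameters live in torch.nn.Parameters so the rest of Lightning can see them.
torch_params = torch.nn.ParameterDict(
    {
        "w": torch.nn.Parameter(torch.randn(4, 3)),
        "b": torch.nn.Parameter(torch.zeros(3)),
    }
)
optimizer = torch.optim.SGD(torch_params.parameters(), lr=0.1)


def manual_step(x: torch.Tensor, y: torch.Tensor) -> float:
    # Copy the current parameter values into Jax arrays (CPU only in this sketch).
    jax_params = {k: jnp.array(p.detach().numpy()) for k, p in torch_params.items()}
    loss, grads = jax.value_and_grad(loss_fn)(
        jax_params, jnp.array(x.numpy()), jnp.array(y.to(torch.int32).numpy())
    )
    loss_value = float(loss)
    optimizer.zero_grad()
    for name, p in torch_params.items():
        p.grad = torch.from_numpy(np.array(grads[name]))  # hand the Jax gradients to PyTorch
    optimizer.step()
    return loss_value


x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
losses = [manual_step(x, y) for _ in range(20)]
```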

What about end-to-end training in Jax?

See the Jax RL Example! 😄

Jax Network#

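import flax.linen
import jax

# `to_channels_last` and `flatten` are small helpers defined elsewhere in the project.


class JaxCNN(flax.linen.Module):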
    """A simple CNN model.

    Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network
    """

    num_classes: int = 10

    @flax.linen.compact
    def __call__(self, x: jax.Array):
        x = to_channels_last(x)
        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = flatten(x)
        x = flax.linen.Dense(features=256)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=self.num_classes)(x)
        return x
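The listing relies on two helpers, `to_channels_last` and `flatten`, whose definitions are not shown here. The sketch below is an assumption about what they might look like for batched 4D inputs, not the project's actual code: PyTorch datamodules typically yield images as (N, C, H, W), while Flax convolutions expect (N, H, W, C).

```python
import jax
import jax.numpy as jnp


def to_channels_last(x: jax.Array) -> jax.Array:
    """Hypothetical helper: (N, C, H, W) -> (N, H, W, C)."""
    return jnp.moveaxis(x, 1, -1)


def flatten(x: jax.Array) -> jax.Array:
    """Hypothetical helper: keep the batch axis, merge all remaining axes."""
    return x.reshape((x.shape[0], -1))


batch = jnp.zeros((2, 3, 32, 32))            # a CIFAR-10-sized batch in torch layout
channels_last = to_channels_last(batch)      # shape (2, 32, 32, 3)
features = flatten(jnp.zeros((2, 8, 8, 64)))  # shape (2, 4096)
```

With Flax's default `SAME` padding for `Conv`, the two 2x2 average-pooling layers reduce a 32x32 input to 8x8, so a CIFAR-10 batch would reach `flatten` with shape (N, 8, 8, 64), which gives 4096 input features for the first `Dense` layer.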

Jax Algorithm#

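import functools
from typing import Literal

import flax.linen
import hydra_zen
import jax
import torch
import torch.nn.functional as F
from lightning import Callback, LightningModule
from torch.optim import Optimizer
from torch_jax_interop import WrappedJaxFunction, torch_to_jax

# `ImageClassificationDataModule`, `HydraConfigFor`, `MeasureSamplesPerSecondCallback` and
# `ClassificationMetricsCallback` are defined elsewhere in the project.


class JaxImageClassifier(LightningModule):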
    """Example of a learning algorithm (`LightningModule`) that uses Jax.

    In this case, the network is a flax.linen.Module whose forward and backward passes are
    written in Jax, while the loss function is in PyTorch.
    """

    def __init__(
        self,
        datamodule: ImageClassificationDataModule,
        network: HydraConfigFor[flax.linen.Module],
        optimizer: HydraConfigFor[functools.partial[Optimizer]],
        init_seed: int = 123,
        debug: bool = True,
    ):
        super().__init__()
        self.datamodule = datamodule
        self.network_config = network
        self.optimizer_config = optimizer
        self.init_seed = init_seed
        self.debug = debug

        # Create the jax network (safe to do even on CPU here).
        self.jax_network: flax.linen.Module = hydra_zen.instantiate(self.network_config)
        # We'll instantiate the parameters and the torch wrapper around the jax network in
        # `configure_model` so the weights are directly on the GPU.
        self.network: torch.nn.Module | None = None
        self.save_hyperparameters(ignore=["datamodule"])

    def configure_model(self):
        example_input = torch.zeros(
            (self.datamodule.batch_size, *self.datamodule.dims),
        )
        # Save this for PyTorch-Lightning to infer the input/output shapes of the network.
        self.example_input_array = example_input

        # Initialize the jax parameters with a forward pass.
        jax_params = self.jax_network.init(
            jax.random.key(self.init_seed), torch_to_jax(example_input)
        )

        jax_network_forward = self.jax_network.apply
        if not self.debug:
            jax_network_forward = jax.jit(jax_network_forward)

        # Wrap the jax network into a nn.Module:
        self.network = WrappedJaxFunction(
            jax_function=jax_network_forward,
            jax_params=jax_params,
            # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
            # Invalid device pointer when trying to share the CUDA tensors that come from jax.
            clone_params=True,
            has_aux=False,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert self.network is not None
        logits = self.network(input)
        return logits

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="train")

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="val")

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="test")

    def shared_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        batch_index: int,
        phase: Literal["train", "val", "test"],
    ):
        # This is the same thing as the `ImageClassifier.shared_step`!
        x, y = batch
        assert not x.requires_grad
        assert self.network is not None
        logits = self.network(x)
        assert isinstance(logits, torch.Tensor)
        # In this example we use a jax "encoder" network and a PyTorch loss function, but we could
        # also just as easily have done the whole forward and backward pass in jax if we wanted to.
        loss = F.cross_entropy(logits, y, reduction="mean")
        acc = logits.argmax(-1).eq(y).float().mean()
        self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True)
        self.log(f"{phase}/acc", acc, prog_bar=True, sync_dist=True)
        return {"loss": loss, "logits": logits, "y": y}

    def configure_optimizers(self):
        """Creates the optimizers.

        See [`lightning.pytorch.core.LightningModule.configure_optimizers`][] for more information.
        """
        # Instantiate the optimizer config into a functools.partial object.
        optimizer_partial = hydra_zen.instantiate(self.optimizer_config)
        # Call the functools.partial object, passing the parameters as an argument.
        optimizer = optimizer_partial(self.parameters())
        # This then returns the optimizer.
        return optimizer

    def configure_callbacks(self) -> list[Callback]:
        assert isinstance(self.datamodule, ImageClassificationDataModule)
        return [
            MeasureSamplesPerSecondCallback(),
            ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes),
        ]
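The two-step pattern in `configure_optimizers` can be tried in isolation. This is a minimal sketch, assuming the instantiated optimizer config behaves like a plain `functools.partial` over a `torch.optim` class; here the partial is built by hand instead of through `hydra_zen.instantiate`, and the `torch.nn.Linear` model is a hypothetical stand-in for the wrapped Jax network.

```python
import functools

import torch

# Step 1: everything except the parameters is fixed ahead of time (e.g. by the config).
optimizer_partial = functools.partial(torch.optim.SGD, lr=0.001)

# Step 2: once the model exists, the parameters are supplied to finish the construction.
model = torch.nn.Linear(4, 2)
optimizer = optimizer_partial(model.parameters())
```

Deferring the call this way lets the config describe the optimizer before any parameters exist, which matters here because the network's parameters are only created in `configure_model`.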

Configs#

LightningModule config#

# Config for the JaxImageClassifier algorithm
defaults:
  - network: jax_cnn
  - optimizer: SGD
_target_: project.algorithms.jax_image_classifier.JaxImageClassifier
# NOTE: Why _partial_ here? Because the config doesn't create the algo directly.
# The datamodule is instantiated first and then passed to the algorithm.
_partial_: true
_recursive_: false

optimizer:
  lr: 0.001

init_seed: 123
debug: False

Running the example#

$ python project/main.py algorithm=jax_image_classifier network=jax_cnn datamodule=cifar10