Jax + PyTorch-Lightning ⚑

JaxExample: a LightningModule that trains a Jax network

The JaxExample algorithm uses a network that is a flax.linen.Module. The network is wrapped with torch_jax_interop.WrappedJaxFunction, so that it accepts torch tensors as inputs, produces torch tensors as outputs, and stores its parameters as torch.nn.Parameters (which, unless they are cloned, use the same underlying memory as the jax arrays). In this example, the loss function and the optimizer are in PyTorch, while the network's forward and backward passes are written in Jax.

The loss returned by the training step is used by Lightning in the usual way: during the backward pass, Jax calculates the gradients of the network parameters, and a PyTorch optimizer then updates the weights.
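Under the hood, the backward pass through a wrapped Jax function can be thought of as a vector-Jacobian product: PyTorch provides the gradient of the loss with respect to the network outputs, and Jax maps it back onto the parameters. The pure-Jax sketch below illustrates that step for a toy linear layer. The apply_fn function and the hand-picked output gradient are assumptions for illustration, not the actual internals of torch_jax_interop.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    # Toy stand-in (hypothetical) for network.apply: a single linear layer.
    return x @ params["w"] + params["b"]


key = jax.random.key(0)
params = {"w": jax.random.normal(key, (4, 3)), "b": jnp.zeros(3)}
x = jnp.ones((2, 4))

# Forward pass in Jax, also returning a function that computes vector-Jacobian products.
logits, vjp_fn = jax.vjp(lambda p: apply_fn(p, x), params)

# Stand-in for dLoss/dLogits, which PyTorch would compute from its own loss.
# Here the "loss" is the mean of all logits, so every entry gets 1 / logits.size.
grad_logits = jnp.full_like(logits, 1.0 / logits.size)

# Map the output gradient back onto the parameters (same pytree structure as params).
(grad_params,) = vjp_fn(grad_logits)

# Plain SGD update (no momentum), as torch.optim.SGD would apply to the wrapped parameters.
lr = 1e-3
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad_params)
```

Because the upstream gradient here is that of a mean, the result matches jax.grad applied directly to the mean of the logits.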

Info

You could also do the entire forward and backward pass in Jax, including the loss! To do this, use the 'manual optimization' mode of PyTorch-Lightning and perform the parameter updates yourself. For the rest of Lightning to work, make sure to still store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly.
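As a rough idea of what such a step could compute, here is a pure-Jax training step where the loss, its gradients and the SGD update all happen in Jax. In a LightningModule, a function like this would be called from training_step with self.automatic_optimization = False. The toy linear model, cross_entropy and train_step are hypothetical, and the conversion between jax arrays and torch.nn.Parameters is left out.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    # Toy stand-in (hypothetical) for a flax network.apply: one linear layer producing logits.
    return x @ params["w"] + params["b"]


def cross_entropy(logits, labels):
    # Mean negative log-likelihood of the integer labels.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()


@jax.jit
def train_step(params, x, y, lr):
    # Loss, gradients and the SGD update are all computed in Jax.
    def loss_fn(p):
        return cross_entropy(apply_fn(p, x), y)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


key_w, key_x = jax.random.split(jax.random.key(0))
params = {"w": 0.01 * jax.random.normal(key_w, (8, 3)), "b": jnp.zeros(3)}
x = jax.random.normal(key_x, (32, 8))
y = jnp.arange(32) % 3  # Arbitrary integer labels for 3 classes.

losses = []
for _ in range(100):
    params, loss = train_step(params, x, y, 0.1)
    losses.append(float(loss))
```

The loss starts close to log(3), as expected for near-zero logits over 3 classes, and decreases as the parameters are updated.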

What about end-to-end training in Jax?

See the Jax RL Example! 😄

Jax Network

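import flax.linen
import jax

# to_channels_last and flatten are helper functions defined elsewhere in the project.


class CNN(flax.linen.Module):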
    """A simple CNN model.

    Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network
    """

    num_classes: int = 10

    @flax.linen.compact
    def __call__(self, x: jax.Array):
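        # Convert (PyTorch-style) channels-first images to the channels-last layout flax.linen.Conv expects.
        x = to_channels_last(x)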
        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = flatten(x)  # Flatten all but the batch dimension before the Dense layers.
        x = flax.linen.Dense(features=256)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=self.num_classes)(x)
        return x
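For a CIFAR-10-sized input (3 × 32 × 32, the datamodule used in the command at the end of this page), the feature size after flatten and the parameter counts of this CNN can be worked out by hand. The sketch assumes flax's defaults for Conv (stride 1, 'SAME' padding) and avg_pool ('VALID' padding), and that flatten collapses all non-batch dimensions. The helper functions are hypothetical.

```python
def conv_params(in_ch, out_ch, k=3):
    # Kernel (k x k x in_ch x out_ch) plus one bias per output channel.
    return k * k * in_ch * out_ch + out_ch


def dense_params(in_features, out_features):
    # Weight matrix plus bias vector.
    return in_features * out_features + out_features


def cnn_shapes(height=32, width=32, channels=3, num_classes=10):
    # 'SAME' convs with stride 1 keep H and W; each 2x2 / stride-2 'VALID' pool halves them.
    h, w = height // 2 // 2, width // 2 // 2
    flat = h * w * 64  # 64 channels out of the second conv.
    params = {
        "conv1": conv_params(channels, 32),
        "conv2": conv_params(32, 64),
        "dense1": dense_params(flat, 256),
        "dense2": dense_params(256, num_classes),
    }
    return flat, params


flat, params = cnn_shapes()
print(flat)                  # 4096 (8 * 8 * 64)
print(sum(params.values()))  # 1070794
```

Almost all of the parameters (1,048,832) sit in the first Dense layer, which is typical of small CNNs that flatten the feature map without global pooling.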

Jax Algorithm

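import dataclasses
import os
from typing import Literal

import flax.linen
import jax
import torch
from lightning import Callback, LightningModule
from torch_jax_interop import WrappedJaxFunction, torch_to_jax

# ImageClassificationDataModule, ClassificationDataModule, MeasureSamplesPerSecondCallback and
# ClassificationMetricsCallback come from the project itself (import paths not shown here).


class JaxExample(LightningModule):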
    """Example of a learning algorithm (`LightningModule`) that uses Jax.

    In this case, the network is a flax.linen.Module whose forward and backward passes are
    written in Jax, while the loss function is in PyTorch.
    """

    @dataclasses.dataclass(frozen=True)
    class HParams:
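        """Hyper-parameters of the algorithm."""

        lr: float = 1e-3  # Learning rate of the SGD optimizer.
        seed: int = 123  # Random seed used to initialize the network parameters.
        debug: bool = True  # If True, network.apply is not compiled with jax.jit.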

    def __init__(
        self,
        *,
        network: flax.linen.Module,
        datamodule: ImageClassificationDataModule,
        hp: HParams = HParams(),
    ):
        super().__init__()
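        # Prevent Jax from pre-allocating most of the GPU memory, which would leave little for
        # PyTorch. This only has an effect if Jax hasn't initialized its GPU backend yet.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"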

        self.datamodule = datamodule
        self.hp = hp or self.HParams()

        example_input = torch.zeros(
            (datamodule.batch_size, *datamodule.dims),
            device=self.device,
        )
        # Initialize the jax parameters with a forward pass.
        params = network.init(jax.random.key(self.hp.seed), x=torch_to_jax(example_input))

        # Wrap the jax network into a nn.Module:
        self.network = WrappedJaxFunction(
            jax_function=jax.jit(network.apply) if not self.hp.debug else network.apply,
            jax_params=params,
            # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
            # Invalid device pointer when trying to share the CUDA tensors that come from jax.
            clone_params=True,
            has_aux=False,
        )

        self.example_input_array = example_input

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        logits = self.network(input)
        return logits

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="train")

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="val")

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="test")

    def shared_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        batch_index: int,
        phase: Literal["train", "val", "test"],
    ):
        x, y = batch
        assert not x.requires_grad
        logits = self.network(x)
        assert isinstance(logits, torch.Tensor)
        # In this example we use a jax "encoder" network and a PyTorch loss function, but we could
        # also just as easily have done the whole forward and backward pass in jax if we wanted to.
        loss = torch.nn.functional.cross_entropy(logits, target=y, reduction="mean")
        acc = logits.argmax(-1).eq(y).float().mean()
        self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True)
        self.log(f"{phase}/acc", acc, prog_bar=True, sync_dist=True)
        return {"loss": loss, "logits": logits, "y": y}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hp.lr)

    def configure_callbacks(self) -> list[Callback]:
        assert isinstance(self.datamodule, ClassificationDataModule)
        return [
            MeasureSamplesPerSecondCallback(),
            ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes),
        ]
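Because HParams is a frozen dataclass, an instance cannot be modified after creation, which is what makes the HParams() default argument of __init__ safe to share between instances. Variants are derived with dataclasses.replace instead. A standalone sketch of that pattern, using a copy of the HParams fields:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class HParams:
    """Copy of JaxExample.HParams, for illustration."""

    lr: float = 1e-3
    seed: int = 123
    debug: bool = True


defaults = HParams()
try:
    defaults.lr = 0.1  # Frozen: assigning to a field raises an error.
except dataclasses.FrozenInstanceError:
    print("HParams instances are immutable")

# Derive a new set of hyper-parameters instead of mutating the shared default.
faster = dataclasses.replace(defaults, lr=1e-2, debug=False)
print(faster)       # HParams(lr=0.01, seed=123, debug=False)
print(defaults.lr)  # 0.001
```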

Configs

JaxExample algorithm config

# Config for the JaxExample algorithm
defaults:
  - network: jax_cnn

_target_: project.algorithms.jax_example.JaxExample
# NOTE: Why _partial_ here? Because the algorithm can't be created from this config alone:
# the datamodule is instantiated first and then passed to the algorithm.
_partial_: true
hp:
  lr: 0.001
  seed: 123
  debug: False
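With _partial_: true, Hydra's hydra.utils.instantiate returns a functools.partial of the _target_ with the arguments from the config already bound, instead of an instance. The stand-in classes below are hypothetical and only mimic that two-stage construction in plain Python: the network and hyper-parameters come from the config, and the datamodule is supplied afterwards.

```python
import functools
from dataclasses import dataclass, field


@dataclass
class FakeDataModule:
    # Hypothetical stand-in for the datamodule, which is instantiated first.
    batch_size: int = 128


@dataclass
class FakeAlgorithm:
    # Hypothetical stand-in for JaxExample's constructor arguments.
    network: str
    datamodule: FakeDataModule
    hp: dict = field(default_factory=dict)


# Stage 1: roughly what instantiating the config with _partial_: true yields.
algorithm_partial = functools.partial(
    FakeAlgorithm, network="jax_cnn", hp={"lr": 0.001, "seed": 123, "debug": False}
)

# Stage 2: the datamodule is created, then passed in to finish building the algorithm.
datamodule = FakeDataModule()
algorithm = algorithm_partial(datamodule=datamodule)
print(algorithm.hp["lr"])  # 0.001
```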

Running the example

$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10