
Supervised Learning (PyTorch)

The ExampleAlgorithm is a simple LightningModule for image classification.

class ExampleAlgorithm(LightningModule):
    """Example learning algorithm for image classification."""

    def __init__(
        self,
        datamodule: ImageClassificationDataModule,
        network: _Config[torch.nn.Module],
        optimizer: _PartialConfig[Optimizer] = AdamConfig(lr=3e-4),
        init_seed: int = 42,
    ):
        """Create a new instance of the algorithm.

        Parameters:
            datamodule: Object used to load train/val/test data.
                See the lightning docs for [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]
                for more info.
            network:
                The config of the network to instantiate and train.
            optimizer: The config for the Optimizer. Instantiating this will return a function \
                (a [functools.partial][]) that will create the Optimizer given the hyper-parameters.
            init_seed: The seed to use when initializing the weights of the network.
        """
        super().__init__()
        self.datamodule = datamodule
        self.network_config = network
        self.optimizer_config = optimizer
        self.init_seed = init_seed

        # Save hyper-parameters.
        self.save_hyperparameters(
            {
                "network_config": self.network_config,
                "optimizer_config": self.optimizer_config,
                "init_seed": init_seed,
            }
        )

        # Small fix for the `device` property in LightningModule, which is CPU by default.
        self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
        # Used by PyTorch Lightning to compute the input/output shapes of the network.
        self.example_input_array = torch.zeros(
            (datamodule.batch_size, *datamodule.dims), device=self.device
        )

        with torch.random.fork_rng():
            # deterministic weight initialization
            torch.manual_seed(self.init_seed)
            self.network = instantiate(self.network_config)

            if any(torch.nn.parameter.is_lazy(p) for p in self.network.parameters()):
                # Do a forward pass to initialize any lazy weights. This is necessary for
                # distributed training and to infer shapes.
                _ = self.network(self.example_input_array)

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the network."""
        logits = self.network(input)
        return logits

    def training_step(self, batch: tuple[Tensor, Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="train")

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="val")

    def test_step(self, batch: tuple[Tensor, Tensor], batch_index: int):
        return self.shared_step(batch, batch_index=batch_index, phase="test")

    def shared_step(
        self,
        batch: tuple[Tensor, Tensor],
        batch_index: int,
        phase: Literal["train", "val", "test"],
    ):
        x, y = batch
        logits: torch.Tensor = self(x)
        loss = F.cross_entropy(logits, y, reduction="mean")
        self.log(f"{phase}/loss", loss.detach().mean())
        acc = logits.detach().argmax(-1).eq(y).float().mean()
        self.log(f"{phase}/accuracy", acc)
        return {"loss": loss, "logits": logits, "y": y}

    def configure_optimizers(self):
        """Creates the optimizers.

        See [`lightning.pytorch.core.LightningModule.configure_optimizers`][] for more information.
        """
        # Instantiate the optimizer config into a functools.partial object.
        optimizer_partial = instantiate(self.optimizer_config)
        # Call the functools.partial object, passing the parameters as an argument.
        optimizer = optimizer_partial(self.parameters())
        # This then returns the optimizer.
        return optimizer

    def configure_callbacks(self) -> Sequence[Callback] | Callback:
        """Creates callbacks to be used by default during training."""
        return [
            ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes)
        ]
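The forward pass inside torch.random.fork_rng() deserves a short note: lazy layers (for example torch.nn.LazyLinear) only materialize their weights once they have seen an input, so their shapes cannot be known, and distributed training cannot broadcast them, until that first call. Here is a minimal, standalone sketch of that behaviour, independent of this project's classes:

import torch

# A lazy layer's parameters are uninitialized until the first forward pass.
net = torch.nn.LazyLinear(out_features=10)
print(any(torch.nn.parameter.is_lazy(p) for p in net.parameters()))  # True

# One forward pass with an example input materializes the weights with the right shapes.
_ = net(torch.zeros(8, 32))
print(any(torch.nn.parameter.is_lazy(p) for p in net.parameters()))  # False
print(net.weight.shape)  # torch.Size([10, 32])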
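The optimizer handling in configure_optimizers can also look indirect at first. Stripped of the project's config classes, it boils down to a functools.partial that holds the hyper-parameters but still needs the model's parameters before it becomes a real Optimizer. A minimal sketch, with a plain functools.partial standing in for the instantiated AdamConfig:

import functools

import torch

# Roughly what instantiate(self.optimizer_config) produces: a partial that holds
# the hyper-parameters but not yet the parameters to optimize.
optimizer_partial = functools.partial(torch.optim.Adam, lr=3e-4)

# What configure_optimizers then does: call that partial with the model's parameters.
model = torch.nn.Linear(32, 10)
optimizer = optimizer_partial(model.parameters())
print(optimizer)  # an Adam optimizer configured with lr=0.0003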

Here is a configuration file that you can use to launch a simple experiment:

# @package _global_

# This is an "experiment" config that groups together other configs into a ready-to-run example.

# To execute this experiment, use:
# python project/main.py experiment=example

defaults:
  - override /algorithm: example
  - override /algorithm/network: resnet18
  - override /datamodule: cifar10
  - override /trainer: default
  - override /trainer/logger: tensorboard
  - override /trainer/callbacks: default

# The parameters below will be merged with parameters from the default configurations set above.
# This allows you to override only the parameters you specify.

# The name of the experiment.
name: example

seed: ${oc.env:SLURM_PROCID,42}

algorithm:
  optimizer:
    lr: 0.002

datamodule:
  batch_size: 64

trainer:
  min_epochs: 1
  max_epochs: 10
  gradient_clip_val: 0.5

You can use it like so:

python project/main.py experiment=example
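
Because the configuration is composed with Hydra, any value shown in the experiment config above can also be overridden directly from the command line, for example (the values here are arbitrary):

python project/main.py experiment=example algorithm.optimizer.lr=0.01 trainer.max_epochs=20

To inspect the fully composed configuration without launching a run, Hydra's standard --cfg flag can be used:

python project/main.py experiment=example --cfg job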