Image classifier

Example of a simple algorithm for image classification.

This can be run from the command line as follows:

python project/ algorithm=image_classification datamodule=cifar10

ImageClassifier

Bases: LightningModule

Example learning algorithm for image classification.

__init__

__init__(
    datamodule: ImageClassificationDataModule,
    network: HydraConfigFor[Module],
    optimizer: HydraConfigFor[partial[Optimizer]],
    init_seed: int = 42,
)

Create a new instance of the algorithm.

Parameters:

datamodule (ImageClassificationDataModule, required): Object used to load the train/val/test data. See the Lightning documentation on LightningDataModule for more information.

network (HydraConfigFor[Module], required): The config of the network to instantiate and train.

optimizer (HydraConfigFor[partial[Optimizer]], required): The config for the Optimizer. Instantiating this config returns a function (a functools.partial) with the optimizer's hyper-parameters already bound; calling that function with the parameters to optimize creates the Optimizer.

init_seed (int, default 42): The seed to use when initializing the weights of the network.
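A fixed init_seed is what makes the network's initial weights reproducible across runs. As a minimal sketch of that idea using only the Python standard library, the hypothetical helper init_weights below draws weights from a random.Random generator seeded with init_seed. It is not part of this module; a PyTorch implementation would seed torch's generator instead (for example with torch.manual_seed) before building the network.

```python
import random


def init_weights(num_weights: int, init_seed: int = 42) -> list[float]:
    """Hypothetical stand-in for network weight initialization.

    A dedicated RNG seeded with `init_seed` draws the initial weights,
    so the same seed always yields the same starting point.
    """
    # Local generator: does not touch the global `random` state.
    rng = random.Random(init_seed)
    return [rng.gauss(0.0, 0.02) for _ in range(num_weights)]


weights_a = init_weights(4, init_seed=42)
weights_b = init_weights(4, init_seed=42)
weights_c = init_weights(4, init_seed=123)
print(weights_a == weights_b)  # True: same seed, same initial weights
print(weights_a == weights_c)  # False: a different seed gives different weights
```

Using a dedicated generator rather than reseeding the global one keeps weight initialization from disturbing other random state, such as data shuffling.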


forward

forward(input: Tensor) -> Tensor

Forward pass of the network.

configure_optimizers

Creates the optimizers.

See lightning.pytorch.core.LightningModule.configure_optimizers for more information.
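configure_optimizers is the natural place to call the optimizer partial described under the optimizer parameter. The sketch below shows that pattern with the standard library only; ToySGD and ToyClassifier are hypothetical stand-ins for a torch.optim.Optimizer subclass and for this LightningModule, and the actual code in ImageClassifier may differ. In Hydra, a config marked with `_partial_: true` instantiates to such a functools.partial.

```python
from functools import partial
from typing import Callable, Iterable


class ToySGD:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass.

    Like torch optimizers, it takes the parameters to optimize first and
    the hyper-parameters as keyword arguments.
    """

    def __init__(self, params: Iterable[float], lr: float, momentum: float = 0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum


class ToyClassifier:
    """Hypothetical module showing how an optimizer partial is consumed."""

    def __init__(self, optimizer_partial: Callable[..., ToySGD]):
        self.optimizer_partial = optimizer_partial
        self.weights = [0.5, -0.25, 1.0]  # stand-in for network parameters

    def parameters(self) -> list[float]:
        return self.weights

    def configure_optimizers(self) -> ToySGD:
        # The hyper-parameters were bound ahead of time (e.g. from a config);
        # only the parameters to optimize are supplied here.
        return self.optimizer_partial(self.parameters())


# Roughly what instantiating a partial optimizer config produces:
optimizer_partial = partial(ToySGD, lr=0.1, momentum=0.9)
optimizer = ToyClassifier(optimizer_partial).configure_optimizers()
print(optimizer.lr, optimizer.momentum, optimizer.params)  # 0.1 0.9 [0.5, -0.25, 1.0]
```

Binding the hyper-parameters first and the parameters last lets the optimizer be fully described by configuration while still being created only after the network exists.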

configure_callbacks

configure_callbacks() -> Sequence[Callback] | Callback

Creates callbacks to be used by default during training.