Algorithm tests

Suite of tests for an "algorithm".

See the project.algorithms.example_test module for an example of how to use this.

LearningAlgorithmTests

Bases: Generic[AlgorithmType], ABC

Suite of unit tests for an "Algorithm" (LightningModule).

Inherit from this class and decorate the subclass with the appropriate pytest markers to get a set of generic unit tests that should apply to any LightningModule.

See the project.algorithms.example_test module for an example.
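The reuse pattern behind this class can be sketched in plain Python. This is a minimal sketch under stated assumptions: `AlgorithmTestsSketch`, `ScaleAlgorithm`, `ScaleAlgorithmTests`, `make_algorithm` and `run_tests` are hypothetical names, and the manual `run_tests` loop only stands in for pytest's own test collection, fixtures and markers.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

AlgorithmType = TypeVar("AlgorithmType")


class AlgorithmTestsSketch(Generic[AlgorithmType], ABC):
    """Hypothetical generic suite: subclasses only say how to build the algorithm."""

    @abstractmethod
    def make_algorithm(self) -> AlgorithmType:
        """Return a fresh algorithm instance (supplied by each subclass)."""

    def test_forward_preserves_batch_size(self) -> None:
        # Written once against the generic interface, reused by every subclass.
        algorithm = self.make_algorithm()
        batch = [1.0, 2.0, 3.0]
        assert len(algorithm.forward(batch)) == len(batch)


class ScaleAlgorithm:
    """Hypothetical stand-in for a LightningModule: multiplies inputs by a weight."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def forward(self, batch: list[float]) -> list[float]:
        return [self.weight * x for x in batch]


class ScaleAlgorithmTests(AlgorithmTestsSketch[ScaleAlgorithm]):
    # Binds AlgorithmType to ScaleAlgorithm; every test_* method is inherited.
    def make_algorithm(self) -> ScaleAlgorithm:
        return ScaleAlgorithm(weight=0.5)


def run_tests(suite: object) -> list[str]:
    # pytest collects test_* methods automatically; this loop only mimics that.
    passed = []
    for name in sorted(dir(suite)):
        if name.startswith("test_"):
            getattr(suite, name)()
            passed.append(name)
    return passed


print(run_tests(ScaleAlgorithmTests()))
```

The design choice is that each test is written once against the generic `AlgorithmType`, so supporting a new algorithm only requires providing the pieces that differ (here, how to build it).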

test_initialization_is_deterministic

test_initialization_is_deterministic(
    experiment_config: Config,
    datamodule: DataModule,
    seed: int,
)

Checks that weight initialization is deterministic: creating the network twice with the same random seed yields identical weights.
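A minimal sketch of what this check amounts to, using Python's `random` module in place of PyTorch's RNG. `build_network` and its layer sizes are hypothetical. The second function is an extra sanity check that is not part of the documented test: it confirms that the seed actually influences the weights, so the determinism check cannot pass merely because initialization ignores randomness.

```python
import random


def build_network(seed: int, sizes: tuple[int, ...] = (3, 4, 2)) -> list[list[float]]:
    """Hypothetical initializer: one flat weight list per layer, drawn from a seeded RNG."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    layers = []
    for fan_in, fan_out in zip(sizes, sizes[1:]):
        bound = 1.0 / fan_in ** 0.5  # uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) init
        layers.append([rng.uniform(-bound, bound) for _ in range(fan_in * fan_out)])
    return layers


def initialization_is_deterministic(seed: int) -> bool:
    # Same seed, two independent builds: every weight must match exactly.
    return build_network(seed) == build_network(seed)


def seed_changes_weights(seed: int) -> bool:
    # Extra sanity check: a different seed should produce different weights.
    return build_network(seed) != build_network(seed + 1)


print(initialization_is_deterministic(42), seed_changes_weights(42))
```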

test_forward_pass_is_deterministic

test_forward_pass_is_deterministic(
    forward_pass_input: Any,
    algorithm: AlgorithmType,
    seed: int,
)

Checks that the forward pass is deterministic: given the same random seed and the same input, it produces the same output.
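Determinism of the forward pass matters most when the model has stochastic layers such as dropout. The sketch below assumes a hypothetical `forward_with_dropout` built on the global `random` generator; re-seeding that generator immediately before each pass is what makes the dropout mask, and therefore the output, repeatable.

```python
import random


def forward_with_dropout(inputs: list[float], weights: list[float], p: float = 0.5) -> list[float]:
    """Hypothetical forward pass: elementwise product followed by inverted dropout."""
    outputs = []
    for x, w in zip(inputs, weights):
        keep = random.random() >= p  # uses the *global* RNG, like a dropout layer
        outputs.append(x * w / (1.0 - p) if keep else 0.0)
    return outputs


def seeded_forward(seed: int, inputs: list[float], weights: list[float]) -> list[float]:
    # Re-seed the global RNG right before the pass so the dropout mask is reproducible.
    random.seed(seed)
    return forward_with_dropout(inputs, weights)


inputs = [1.0, 2.0, 3.0, 4.0]
weights = [0.5, -1.0, 0.25, 2.0]
print(seeded_forward(7, inputs, weights) == seeded_forward(7, inputs, weights))
```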

test_backward_pass_is_deterministic

test_backward_pass_is_deterministic(
    datamodule: LightningDataModule,
    algorithm: AlgorithmType,
    seed: int,
    accelerator: str,
    devices: int | list[int] | Literal["auto"],
    tmp_path: Path,
)

Checks that the backward pass is deterministic: given the same input, weights, and random seed, it produces the same gradients.
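A sketch of the same idea for gradients, assuming a hypothetical one-feature linear model trained with a mean-squared-error loss. The seed drives both the initial weight and the choice of mini-batch, so two backward passes under the same seed should return identical gradients.

```python
import random


def mse_gradient(w: float, b: float, batch: list[tuple[float, float]]) -> tuple[float, float]:
    """Analytic gradient of mean((w*x + b - y)**2) with respect to (w, b)."""
    n = len(batch)
    grad_w = sum(2.0 * (w * x + b - y) * x for x, y in batch) / n
    grad_b = sum(2.0 * (w * x + b - y) for x, y in batch) / n
    return grad_w, grad_b


def backward_pass(seed: int, data: list[tuple[float, float]], batch_size: int = 4) -> tuple[float, float]:
    # The seed controls both the initial weight and which samples form the batch.
    rng = random.Random(seed)
    w, b = rng.gauss(0.0, 1.0), 0.0
    batch = rng.sample(data, batch_size)
    return mse_gradient(w, b, batch)


data = [(float(x), 3.0 * x + 1.0) for x in range(10)]
print(backward_pass(0, data) == backward_pass(0, data))
```

On GPUs some PyTorch kernels are nondeterministic even with a fixed seed; `torch.use_deterministic_algorithms(True)` makes PyTorch use deterministic implementations where available and raise an error for operations that have none, which can help track down a failing determinism test.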

test_initialization_is_reproducible

test_initialization_is_reproducible(
    experiment_config: Config,
    datamodule: DataModule,
    seed: int,
    tensor_regression: TensorRegressionFixture,
)

Check that the network initialization is reproducible given the same random seed.
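The `tensor_regression` parameter suggests that, unlike the determinism tests, the reproducibility tests compare against values saved by an earlier run, so they can also catch changes introduced by code edits or dependency upgrades. The sketch below imitates that workflow with a JSON file. `check_regression` and `init_weights` are hypothetical stand-ins; the real `tensor_regression` fixture's storage format and tolerances are not described here.

```python
import json
import math
import random
import tempfile
from pathlib import Path


def check_regression(path: Path, values: dict[str, list[float]], rel_tol: float = 1e-6) -> str:
    """Hypothetical stand-in for a regression fixture.

    First call: save `values` to `path` and return "created".
    Later calls: compare with the saved values and raise AssertionError on drift.
    """
    if not path.exists():
        path.write_text(json.dumps(values, indent=2))
        return "created"
    stored = json.loads(path.read_text())
    if stored.keys() != values.keys():
        raise AssertionError(f"parameter names changed: {sorted(stored)} vs {sorted(values)}")
    for name, expected in stored.items():
        actual = values[name]
        same_length = len(actual) == len(expected)
        if not same_length or not all(
            math.isclose(a, e, rel_tol=rel_tol) for a, e in zip(actual, expected)
        ):
            raise AssertionError(f"{name} differs from the saved reference")
    return "matched"


def init_weights(seed: int) -> dict[str, list[float]]:
    # Hypothetical network with named parameters drawn from a seeded generator.
    rng = random.Random(seed)
    return {
        "layer1.weight": [rng.gauss(0.0, 1.0) for _ in range(6)],
        "layer1.bias": [0.0, 0.0],
    }


results = []
with tempfile.TemporaryDirectory() as tmp:
    reference = Path(tmp) / "init_seed_42.json"
    results.append(check_regression(reference, init_weights(42)))  # first run: file written
    results.append(check_regression(reference, init_weights(42)))  # later run: compared
    try:
        check_regression(reference, init_weights(43))  # different seed: should drift
        drift_detected = False
    except AssertionError:
        drift_detected = True

print(results, drift_detected)
```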

test_forward_pass_is_reproducible

test_forward_pass_is_reproducible(
    forward_pass_input: Any,
    algorithm: AlgorithmType,
    seed: int,
    tensor_regression: TensorRegressionFixture,
)

Check that the forward pass is reproducible given the same input and random seed.
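When outputs are compared against values saved on another run, possibly on another machine or device, exact float equality can be too strict: the same dot product accumulated in a different order can differ in the last bits. The sketch below shows this with a hypothetical `linear_forward` and compares with a tolerance using the rule `numpy.allclose` applies; the `rtol` and `atol` defaults shown are numpy's, not necessarily those of the `tensor_regression` fixture.

```python
def allclose(actual: list[float], expected: list[float], rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Elementwise |a - e| <= atol + rtol * |e|, the same rule numpy.allclose uses."""
    return len(actual) == len(expected) and all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )


def linear_forward(inputs: list[float], weights: list[float], reverse: bool = False) -> float:
    # Same dot product, accumulated in either order, as different devices or kernels might do.
    terms = [x * w for x, w in zip(inputs, weights)]
    total = 0.0
    for term in (reversed(terms) if reverse else terms):
        total += term
    return total


inputs = [0.1, 0.2, 0.3]
weights = [1.0, 1.0, 1.0]
forward_a = [linear_forward(inputs, weights)]
forward_b = [linear_forward(inputs, weights, reverse=True)]
print(forward_a == forward_b, allclose(forward_a, forward_b))  # exact vs tolerant comparison
```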

test_backward_pass_is_reproducible

test_backward_pass_is_reproducible(
    datamodule: LightningDataModule,
    algorithm: AlgorithmType,
    seed: int,
    accelerator: str,
    devices: int | list[int],
    tensor_regression: TensorRegressionFixture,
    tmp_path: Path,
)

Check that the backward pass is reproducible given the same weights, inputs and random seed.

forward_pass_input

forward_pass_input(
    training_batch: PyTree[Tensor], device: device
)

Extracts the model input from a batch of data coming from the dataloader.

Override this if your batches are not tuples of tensors (i.e., if your algorithm isn't a simple supervised learning algorithm like the example).
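The sketch below contrasts a default that assumes `(inputs, targets)` tuples with a hypothetical override for dictionary batches. It is plain Python: in the test class, `forward_pass_input` is a fixture that also receives `device`, which these stand-in functions omit, and the `"image"` key is an assumption about the batch layout.

```python
from typing import Any


def default_forward_pass_input(training_batch: Any) -> Any:
    """Sketch of a supervised default: batches are (inputs, targets) tuples."""
    if not isinstance(training_batch, tuple):
        raise TypeError("expected an (inputs, targets) tuple; override for other layouts")
    inputs, _targets = training_batch
    return inputs


def dict_forward_pass_input(training_batch: dict[str, Any]) -> Any:
    """Hypothetical override for dict batches such as {"image": ..., "label": ...}."""
    # Keep only the model input; the labels are not needed for a forward pass.
    return training_batch["image"]


images = [[0.0, 1.0], [2.0, 3.0]]
supervised_batch = (images, [0, 1])
dict_batch = {"image": images, "label": [0, 1]}
print(default_forward_pass_input(supervised_batch) == dict_forward_pass_input(dict_batch))
```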

do_one_step_of_training

do_one_step_of_training(
    algorithm: AlgorithmType,
    datamodule: LightningDataModule,
    accelerator: str,
    devices: int | list[int] | Literal["auto"],
    callbacks: list[Callback],
    tmp_path: Path,
)

Performs one step of training.

Override this if you train your algorithm differently.
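How the default performs the step is not described here. As a hypothetical illustration of an override for an algorithm trained by hand, the sketch below runs one gradient-descent step on a one-parameter model and lets callbacks observe the gradient before the update. `ManualAlgorithm` and `manual_training_step` are invented for this example, and the callbacks are plain callables rather than Lightning `Callback` objects.

```python
from typing import Callable


class ManualAlgorithm:
    """Hypothetical algorithm: fits y = w * x with plain gradient descent."""

    def __init__(self, w: float = 0.0, lr: float = 0.1) -> None:
        self.w = w
        self.lr = lr

    def gradient(self, batch: list[tuple[float, float]]) -> float:
        # d/dw of mean((w * x - y) ** 2)
        return sum(2.0 * (self.w * x - y) * x for x, y in batch) / len(batch)


def manual_training_step(
    algorithm: ManualAlgorithm,
    batch: list[tuple[float, float]],
    callbacks: list[Callable[[float], None]],
) -> None:
    # 1) Backward: compute the gradient for this batch.
    grad = algorithm.gradient(batch)
    # 2) Let callbacks inspect the gradient before it is applied.
    for callback in callbacks:
        callback(grad)
    # 3) Optimizer step: move the weight against the gradient.
    algorithm.w -= algorithm.lr * grad


seen_gradients: list[float] = []
algo = ManualAlgorithm()
manual_training_step(algo, [(1.0, 2.0), (2.0, 4.0)], callbacks=[seen_gradients.append])
print(algo.w, seen_gradients)
```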

forward_pass

forward_pass(
    algorithm: LightningModule, input: PyTree[Tensor]
)

Performs the forward pass with the LightningModule, unpacking the inputs if necessary.
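A sketch of what "unpacking the inputs if necessary" can look like, assuming tuples are spread as positional arguments and dicts as keyword arguments; the exact rules of the real method are not stated here. The sketch calls the module itself rather than its `forward` method, as is usual for a LightningModule, since `nn.Module.__call__` runs any registered hooks around `forward`. `two_input_model` is a hypothetical stand-in.

```python
from typing import Any, Callable


def forward_pass(algorithm: Callable[..., Any], inputs: Any) -> Any:
    """Sketch: call the module, unpacking tuples positionally and dicts as keywords."""
    if isinstance(inputs, dict):
        return algorithm(**inputs)  # {"x": a, "y": b} -> algorithm(x=a, y=b)
    if isinstance(inputs, tuple):
        return algorithm(*inputs)  # (a, b) -> algorithm(a, b)
    return algorithm(inputs)  # a single input is passed as-is


def two_input_model(x: float, y: float = 0.0) -> float:
    # Hypothetical module whose forward takes one required and one optional argument.
    return x + 10.0 * y


print(
    forward_pass(two_input_model, 1.0),
    forward_pass(two_input_model, (1.0, 2.0)),
    forward_pass(two_input_model, {"x": 1.0, "y": 2.0}),
)
```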