
Lightning module tests

Suite of tests for a LightningModule.

See the project.algorithms.image_classifier_test module for an example of how to use this.

LightningModuleTests

Bases: Generic[AlgorithmType], ABC

Suite of generic tests for a LightningModule.

Inherit from this class and decorate the subclass with the appropriate markers to get a set of unit tests that should apply to any LightningModule.

See the project.algorithms.image_classifier_test module for an example.

Other ideas:
- pytest-benchmark for regression tests on the speed of the forward pass, backward pass, and training step
- pytest-profiling for profiling the training step (is there a PyTorch variant?)
- Dataset splits: check some basic statistics of the train/val/test inputs; are they roughly similar?
- Define the input as a space, then check that the dataset samples lie in that space and that not too many samples are statistically out-of-distribution (OOD)

experiment_config

experiment_config(
    experiment_dictconfig: DictConfig,
) -> Config

The experiment configuration, with all interpolations resolved.

algorithm

algorithm(
    experiment_config: Config,
    datamodule: LightningDataModule | None,
    trainer: Trainer | JaxTrainer,
    device: device,
)

Fixture that creates the "algorithm" (a LightningModule).

make_torch_deterministic

make_torch_deterministic()

Set torch to deterministic mode for unit tests that use the tensor_regression fixture.
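A minimal sketch of what "deterministic mode" can involve, written as a context manager (the name `torch_deterministic_mode` is hypothetical, not this module's API). It enables `torch.use_deterministic_algorithms`, pins cuDNN to deterministic kernels, and restores the previous settings afterwards so other tests are unaffected. On CUDA, deterministic cuBLAS additionally requires the `CUBLAS_WORKSPACE_CONFIG` environment variable to be set; that is not handled here.

```python
import contextlib
from collections.abc import Iterator

import torch


@contextlib.contextmanager
def torch_deterministic_mode() -> Iterator[None]:
    """Enable deterministic torch algorithms, restoring the previous settings on exit."""
    prev_enabled = torch.are_deterministic_algorithms_enabled()
    prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    prev_cudnn_deterministic = torch.backends.cudnn.deterministic
    prev_cudnn_benchmark = torch.backends.cudnn.benchmark

    # Raise an error whenever an op without a deterministic implementation is used.
    torch.use_deterministic_algorithms(True)
    # cuDNN autotuning may pick different kernels between runs, so disable it.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(prev_enabled, warn_only=prev_warn_only)
        torch.backends.cudnn.deterministic = prev_cudnn_deterministic
        torch.backends.cudnn.benchmark = prev_cudnn_benchmark
```

Wrapping the body in a pytest fixture that yields inside the `with` block would give the same setup/teardown behaviour per test.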

seed

seed(request: FixtureRequest)

Fixture that seeds everything for reproducibility and yields the random seed used.
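The seed-then-yield pattern can be sketched with the standard library alone. This is a simplified stand-in (the name `seeded` is hypothetical): it seeds only Python's `random` module and restores its state afterwards, whereas a real "seed everything" fixture would also seed NumPy and torch (for example via Lightning's `seed_everything`) and would typically read the seed from `request.param`.

```python
import random
from collections.abc import Iterator


def seeded(seed: int = 42) -> Iterator[int]:
    """Generator-style fixture body: seed the RNG, yield the seed, then restore the RNG state."""
    previous_state = random.getstate()
    random.seed(seed)
    try:
        yield seed  # the test receives the seed that was used
    finally:
        # Runs when the generator is closed (pytest does this at teardown).
        random.setstate(previous_state)
```

Decorating such a generator with `@pytest.fixture` makes pytest run the code before `yield` as setup and the `finally` block as teardown.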

training_step_content

training_step_content(
    datamodule: LightningDataModule | None,
    algorithm: AlgorithmType,
    seed: int,
    accelerator: str,
    devices: int | list[int],
    tmp_path_factory: TempPathFactory,
)

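Fixture that runs a first training step and provides its content (the algorithm, the GetStuffFromFirstTrainingStep callback, and two lists of captured values) to the reproducibility tests below.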

test_initialization_is_reproducible

test_initialization_is_reproducible(
    training_step_content: tuple[
        AlgorithmType,
        GetStuffFromFirstTrainingStep,
        list[Any],
        list[Any],
    ],
    tensor_regression: TensorRegressionFixture,
    accelerator: str,
)

Check that the network initialization is reproducible given the same random seed.
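The property being checked can be shown in-process with a small, hypothetical network: building it twice after seeding with the same value yields identical parameters. The actual test instead compares against values stored by the `tensor_regression` fixture, so it also catches changes across runs and code versions.

```python
import torch
from torch import nn


def make_network(seed: int) -> nn.Module:
    """Build a small (hypothetical) network after seeding torch's global RNG."""
    torch.manual_seed(seed)
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def same_initialization(seed: int) -> bool:
    """Return True if two networks built with the same seed have identical parameters."""
    first = make_network(seed).state_dict()
    second = make_network(seed).state_dict()
    return first.keys() == second.keys() and all(
        torch.equal(first[name], second[name]) for name in first
    )
```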

test_forward_pass_is_reproducible

test_forward_pass_is_reproducible(
    training_step_content: tuple[
        AlgorithmType,
        GetStuffFromFirstTrainingStep,
        list[Any],
        list[Any],
    ],
    tensor_regression: TensorRegressionFixture,
)

Check that the forward pass is reproducible given the same input and random seed.
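The random seed matters for the forward pass because layers such as dropout draw random values in training mode. A minimal sketch with a hypothetical model: re-seeding before each call makes two forward passes on the same input produce identical outputs.

```python
import torch
from torch import nn


def forward_with_seed(model: nn.Module, inputs: torch.Tensor, seed: int) -> torch.Tensor:
    """Run a forward pass after re-seeding, so random layers (e.g. dropout) draw the same values."""
    torch.manual_seed(seed)
    return model(inputs)


torch.manual_seed(0)
# Dropout is active because the model is in training mode.
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2)).train()
inputs = torch.randn(3, 4)
```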

test_backward_pass_is_reproducible

test_backward_pass_is_reproducible(
    training_step_content: tuple[
        AlgorithmType,
        GetStuffFromFirstTrainingStep,
        list[Any],
        list[Any],
    ],
    tensor_regression: TensorRegressionFixture,
    accelerator: str,
)

Check that the backward pass is reproducible given the same weights, inputs and random seed.
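A sketch of the same check for gradients, under the assumption of a simple classification loss (the model, loss, and helper name are hypothetical). A deep copy gives a second model with the same weights; with the same inputs and seed, both backward passes produce identical gradients.

```python
import copy

import torch
from torch import nn


def gradients_after_backward(
    model: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, seed: int
) -> dict[str, torch.Tensor]:
    """Compute one loss, backpropagate it, and return a copy of every parameter's gradient."""
    model.zero_grad(set_to_none=True)
    torch.manual_seed(seed)  # fixes the dropout masks used in this forward pass
    loss = nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    return {name: param.grad.detach().clone() for name, param in model.named_parameters()}


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 3)).train()
inputs = torch.randn(5, 4)
targets = torch.randint(0, 3, (5,))
```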

forward_pass_input

forward_pass_input(
    training_batch: PyTree[Tensor], device: device
)

Extracts the model input from a batch of data coming from the dataloader.

Override this if your batches are not tuples of tensors (for example, if your algorithm is not a simple supervised learning algorithm like the example).
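The two functions below sketch the kind of logic involved, written as standalone functions for brevity. The first assumes supervised `(inputs, targets)` batches; the second shows what an override could look like for dict-shaped batches, where the key `"image"` is hypothetical. In practice the body would go in a `forward_pass_input` method on your `LightningModuleTests` subclass.

```python
import torch


def tuple_batch_input(
    training_batch: tuple[torch.Tensor, torch.Tensor], device: torch.device
) -> torch.Tensor:
    """Supervised-style batches: keep the inputs of an (inputs, targets) tuple."""
    inputs, _targets = training_batch
    return inputs.to(device)


def dict_batch_input(training_batch: dict[str, torch.Tensor], device: torch.device) -> torch.Tensor:
    """Dict-style batches: select the model input by key ("image" is a hypothetical key)."""
    return training_batch["image"].to(device)
```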

do_one_step_of_training

do_one_step_of_training(
    algorithm: AlgorithmType,
    datamodule: LightningDataModule | None,
    accelerator: str,
    devices: int | list[int] | Literal["auto"],
    callbacks: list[Callback],
    tmp_path: Path,
)

Performs one step of training.

Override this if you train your algorithm differently.

GetStuffFromFirstTrainingStep

Bases: Callback

Callback used in tests to capture values from the first call to training_step.

convert_list_and_tuples_to_dicts

convert_list_and_tuples_to_dicts(value: Any) -> Any

Converts all lists and tuples in a nested structure to dictionaries.

>>> convert_list_and_tuples_to_dicts([1, 2, 3])
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts((1, 2, 3))
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts({"a": [1, 2, 3], "b": (4, 5, 6)})
{'a': {'0': 1, '1': 2, '2': 3}, 'b': {'0': 4, '1': 5, '2': 6}}
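A short recursive sketch that reproduces the documented examples: dict keys are kept, list and tuple positions become string keys, and everything else is returned unchanged. It is a reconstruction from the examples only; the project's own version may treat other containers (such as named tuples) differently.

```python
from typing import Any


def convert_list_and_tuples_to_dicts(value: Any) -> Any:
    """Recursively replace lists and tuples with dicts keyed by the stringified index."""
    if isinstance(value, dict):
        return {key: convert_list_and_tuples_to_dicts(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return {str(index): convert_list_and_tuples_to_dicts(item) for index, item in enumerate(value)}
    return value  # leaves (numbers, strings, tensors, ...) are kept as-is
```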