Lightning module tests
Suite of tests for a LightningModule.
See the project.algorithms.image_classifier_test module for an example of how to use this.
LightningModuleTests #
Bases: Generic[AlgorithmType], ABC
Suite of generic tests for a LightningModule.
Simply inherit from this class and decorate your subclass with the appropriate markers to get a set of decent unit tests that should apply to any LightningModule.
See the project.algorithms.image_classifier_test module for an example.
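For instance, a concrete test suite can be as small as the sketch below. The import paths, the ImageClassifier class and the marker are assumptions loosely based on the example module named above, not taken verbatim from it.

```python
import pytest

from project.algorithms.image_classifier import ImageClassifier  # assumed path
from project.algorithms.testsuites.lightning_module_tests import LightningModuleTests  # assumed path


@pytest.mark.slow  # replace with whatever markers apply to your algorithm
class TestImageClassifier(LightningModuleTests[ImageClassifier]):
    """Inherits the generic initialization / forward / backward reproducibility
    tests and runs them against the ImageClassifier LightningModule."""
```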
Other ideas:

- pytest-benchmark for regression tests on forward / backward pass / training step speed
- pytest-profiling for profiling the training step? (pytorch variant?)
- Dataset splits: check some basic stats about the train/val/test inputs, are they somewhat similar?
- Define the input as a space, check that the dataset samples are in that space and not too many samples are statistically OOD?
experiment_config #
experiment_config(
experiment_dictconfig: DictConfig,
) -> Config
The experiment configuration, with all interpolations resolved.
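In other words, any ${...} references in the raw DictConfig are replaced by their values before the structured Config is built. A tiny OmegaConf-only illustration (the keys below are made up):

```python
from omegaconf import OmegaConf

raw = OmegaConf.create({"seed": 123, "log_dir": "logs/seed_${seed}"})
OmegaConf.resolve(raw)  # resolves the ${seed} interpolation in place
assert raw.log_dir == "logs/seed_123"
```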
algorithm #
algorithm(
experiment_config: Config,
datamodule: LightningDataModule | None,
trainer: Trainer | JaxTrainer,
device: device,
)
Fixture that creates the "algorithm" (a LightningModule).
make_torch_deterministic #
Set torch to deterministic mode for unit tests that use the tensor_regression fixture.
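A sketch of what such an autouse fixture might look like (the actual implementation isn't shown in these docs):

```python
import pytest
import torch


@pytest.fixture(autouse=True)
def make_torch_deterministic():
    # Force deterministic kernels so that tensor regression files are stable
    # across runs on the same hardware.
    previous = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    yield
    torch.use_deterministic_algorithms(previous)
```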
seed #
seed(request: FixtureRequest)
Fixture that seeds everything for reproducibility and yields the random seed used.
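A sketch of such a fixture, assuming the seed can optionally come from indirect parametrization (the default value of 42 is arbitrary):

```python
import random

import numpy as np
import pytest
import torch


@pytest.fixture
def seed(request: pytest.FixtureRequest):
    random_seed = getattr(request, "param", 42)  # allow indirect parametrization
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    yield random_seed
```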
training_step_content #
training_step_content(
datamodule: LightningDataModule | None,
algorithm: AlgorithmType,
seed: int,
accelerator: str,
devices: int | list[int],
tmp_path_factory: TempPathFactory,
)
Fixture that runs a single training step and collects the tensors (initial weights, training batch, outputs and gradients) used by the reproducibility tests below.
test_initialization_is_reproducible #
test_initialization_is_reproducible(
training_step_content: tuple[
AlgorithmType,
GetStuffFromFirstTrainingStep,
list[Any],
list[Any],
],
tensor_regression: TensorRegressionFixture,
accelerator: str,
)
Check that the network initialization is reproducible given the same random seed.
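The idea behind this and the two tests below is a tensor regression check: dump the relevant tensors to a file on the first run, then compare against that file on later runs. A rough sketch for the initialization case (the exact TensorRegressionFixture.check call is an assumption about its API):

```python
def test_initialization_is_reproducible(training_step_content, tensor_regression, accelerator):
    algorithm, *_ = training_step_content
    # With a fixed seed, re-creating the module must yield identical parameters,
    # so the state_dict can be compared against a stored regression file.
    tensor_regression.check(
        {name: tensor.detach().cpu() for name, tensor in algorithm.state_dict().items()}
    )
```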
test_forward_pass_is_reproducible #
test_forward_pass_is_reproducible(
training_step_content: tuple[
AlgorithmType,
GetStuffFromFirstTrainingStep,
list[Any],
list[Any],
],
tensor_regression: TensorRegressionFixture,
)
Check that the forward pass is reproducible given the same input and random seed.
test_backward_pass_is_reproducible #
test_backward_pass_is_reproducible(
training_step_content: tuple[
AlgorithmType,
GetStuffFromFirstTrainingStep,
list[Any],
list[Any],
],
tensor_regression: TensorRegressionFixture,
accelerator: str,
)
Check that the backward pass is reproducible given the same weights, inputs and random seed.
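The backward-pass check follows the same pattern, this time on the gradients produced by the first (and only) training step. A rough sketch, assuming the gradients are still present on the parameters after that step:

```python
def test_backward_pass_is_reproducible(training_step_content, tensor_regression, accelerator):
    algorithm, *_ = training_step_content
    # Gradients left on the parameters after the single training step must match
    # the stored regression file when the weights, inputs and seed are the same.
    gradients = {
        name: param.grad.detach().cpu()
        for name, param in algorithm.named_parameters()
        if param.grad is not None
    }
    tensor_regression.check(gradients)
```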
forward_pass_input #
Extracts the model input from a batch of data coming from the dataloader.
Override this if your batches are not tuples of tensors (e.g. if your algorithm isn't a simple supervised learning algorithm like the example).
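For instance, if your dataloader yields dictionaries instead of (input, target) tuples, the override could look roughly like this (the method signature, the string type parameter and the "image" key are illustrative assumptions):

```python
import torch

from project.algorithms.testsuites.lightning_module_tests import LightningModuleTests  # assumed path


class TestMyDictBatchAlgorithm(LightningModuleTests["MyAlgorithm"]):  # "MyAlgorithm": your LightningModule
    def forward_pass_input(self, training_batch: dict[str, torch.Tensor], device: torch.device):
        # Batches are dicts here, so pick out the tensor the network expects.
        return training_batch["image"].to(device)
```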
do_one_step_of_training #
do_one_step_of_training(
algorithm: AlgorithmType,
datamodule: LightningDataModule | None,
accelerator: str,
devices: int | list[int] | Literal["auto"],
callbacks: list[Callback],
tmp_path: Path,
)
Performs one step of training.
Override this if you train your algorithm differently.
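A plausible default, sketched with a vanilla Lightning Trainer. The Trainer arguments and the return value are assumptions; adapt this to however your algorithm is actually trained:

```python
from pathlib import Path
from typing import Literal

from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer


def do_one_step_of_training(
    algorithm: LightningModule,
    datamodule: LightningDataModule | None,
    accelerator: str,
    devices: int | list[int] | Literal["auto"],
    callbacks: list[Callback],
    tmp_path: Path,
):
    # Run exactly one training batch so the callbacks can record its tensors.
    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=0,
        callbacks=callbacks,
        default_root_dir=tmp_path,
        enable_checkpointing=False,
        logger=False,
    )
    trainer.fit(algorithm, datamodule=datamodule)
    return callbacks
```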
GetStuffFromFirstTrainingStep #
Bases: Callback
Callback used in tests to get things from the first call to training_step.
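A minimal sketch of what such a callback could capture; the attribute names (batch, outputs) are assumptions, not the actual implementation:

```python
from typing import Any

from lightning.pytorch import Callback, LightningModule, Trainer


class GetStuffFromFirstTrainingStepSketch(Callback):
    """Records the first training batch and the outputs of the first training_step."""

    def __init__(self) -> None:
        super().__init__()
        self.batch: Any = None
        self.outputs: Any = None

    def on_train_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        if batch_idx == 0:
            self.batch = batch
            self.outputs = outputs
```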
convert_list_and_tuples_to_dicts #
Converts all lists and tuples in a nested structure to dictionaries.
>>> convert_list_and_tuples_to_dicts([1, 2, 3])
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts((1, 2, 3))
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts({"a": [1, 2, 3], "b": (4, 5, 6)})
{'a': {'0': 1, '1': 2, '2': 3}, 'b': {'0': 4, '1': 5, '2': 6}}
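A recursive sketch that reproduces the behaviour shown above (not necessarily the actual implementation):

```python
from typing import Any


def convert_list_and_tuples_to_dicts(value: Any) -> Any:
    # Recurse into containers, turning positional indices into string keys so the
    # nested structure can be handled like a flat mapping of names to values.
    if isinstance(value, dict):
        return {key: convert_list_and_tuples_to_dicts(v) for key, v in value.items()}
    if isinstance(value, (list, tuple)):
        return {str(i): convert_list_and_tuples_to_dicts(v) for i, v in enumerate(value)}
    return value
```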