Algorithm tests
Suite of tests for an "algorithm".
See the project.algorithms.example_test module for an example of how to use this.
LearningAlgorithmTests
Bases: Generic[AlgorithmType], ABC
Suite of unit tests for an "Algorithm" (LightningModule).
Inherit from this class and decorate the subclass with the appropriate markers to get a baseline set of unit tests that should apply to any LightningModule.
See the project.algorithms.example_test module for an example.
test_initialization_is_deterministic
test_initialization_is_deterministic(
experiment_config: Config,
datamodule: DataModule,
seed: int,
)
Checks that the weight initialization is consistent given a random seed.
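A minimal sketch of the kind of check this describes, in plain PyTorch. The helper names `make_network` and `same_weights` and the toy architecture are assumptions made for illustration; the actual test builds the algorithm from `experiment_config` and `datamodule`.

```python
import torch
from torch import nn


def make_network(seed: int) -> nn.Module:
    """Seed the global RNG, then build a small (hypothetical) network."""
    torch.manual_seed(seed)
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def same_weights(a: nn.Module, b: nn.Module) -> bool:
    """True if both modules have identical parameter names and values."""
    state_a, state_b = a.state_dict(), b.state_dict()
    return state_a.keys() == state_b.keys() and all(
        torch.equal(state_a[key], state_b[key]) for key in state_a
    )


first = make_network(seed=42)
second = make_network(seed=42)
third = make_network(seed=43)

same_seed_matches = same_weights(first, second)
other_seed_matches = same_weights(first, third)
```

Comparing against a different seed is a useful sanity check: if the weights matched for every seed, the test would pass trivially (for example with constant initialization).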
test_forward_pass_is_deterministic
Checks that the forward pass output is consistent given a random seed and a given input.
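As an illustration of why the seed matters for the forward pass, the sketch below uses a toy network with dropout, whose output depends on the random number generator in training mode. The network and the `seeded_forward` helper are assumptions for this example, not part of the test suite.

```python
import torch
from torch import nn

# Hypothetical toy network; dropout makes the forward pass stochastic.
torch.manual_seed(0)
network = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
network.train()  # keep dropout active
fixed_input = torch.ones(3, 4)


def seeded_forward(seed: int) -> torch.Tensor:
    """Reset the RNG to `seed`, then run one forward pass on the fixed input."""
    torch.manual_seed(seed)
    with torch.no_grad():
        return network(fixed_input)


out_a = seeded_forward(seed=123)
out_b = seeded_forward(seed=123)
outputs_match = torch.equal(out_a, out_b)
```

Here the same seed and the same input produce the same dropout mask, so the two outputs are identical on CPU. On GPU, some operations may need additional determinism settings.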
test_backward_pass_is_deterministic
test_backward_pass_is_deterministic(
datamodule: LightningDataModule,
algorithm: AlgorithmType,
seed: int,
accelerator: str,
devices: int | list[int] | Literal["auto"],
tmp_path: Path,
)
Check that the backward pass is reproducible given the same input, weights, and random seed.
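The idea can be sketched in plain PyTorch by computing gradients twice from identical weights, inputs, and seed, then comparing them. This is only an illustration: the helper `gradients_for_one_batch`, the toy network, and the MSE loss are assumptions, and the actual test takes `accelerator`, `devices`, and `tmp_path`, which are not modeled here.

```python
import copy

import torch
from torch import nn


def gradients_for_one_batch(
    network: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, seed: int
) -> dict[str, torch.Tensor]:
    """Run forward and backward on a copy of `network`; return its gradients."""
    network = copy.deepcopy(network)  # same starting weights, no stale grads
    torch.manual_seed(seed)  # fixes the dropout mask used in the forward pass
    loss = nn.functional.mse_loss(network(inputs), targets)
    loss.backward()
    return {name: p.grad.clone() for name, p in network.named_parameters()}


# Hypothetical toy setup.
torch.manual_seed(0)
network = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
inputs, targets = torch.randn(16, 4), torch.randn(16, 1)

grads_a = gradients_for_one_batch(network, inputs, targets, seed=7)
grads_b = gradients_for_one_batch(network, inputs, targets, seed=7)
gradients_match = all(torch.equal(grads_a[k], grads_b[k]) for k in grads_a)
```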
test_initialization_is_reproducible
test_initialization_is_reproducible(
experiment_config: Config,
datamodule: DataModule,
seed: int,
tensor_regression: TensorRegressionFixture,
)
Check that the network initialization is reproducible given the same random seed.
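The `tensor_regression` parameter suggests that values are compared against references saved by an earlier run, although this page does not spell that out, so treat it as an assumption. The sketch below hand-rolls that idea with `torch.save` and `torch.load`; `check_against_reference` and `initial_weights` are hypothetical helpers, not the fixture's API.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def check_against_reference(
    values: dict[str, torch.Tensor], reference_file: Path
) -> bool:
    """Store `values` on the first call; compare against the file afterwards."""
    if not reference_file.exists():
        torch.save(values, reference_file)
        return True
    reference = torch.load(reference_file)
    return values.keys() == reference.keys() and all(
        torch.equal(values[key], reference[key]) for key in values
    )


def initial_weights(seed: int) -> dict[str, torch.Tensor]:
    """Seed the RNG and return the initial weights of a toy layer."""
    torch.manual_seed(seed)
    return dict(nn.Linear(4, 2).state_dict())


with tempfile.TemporaryDirectory() as tmp_dir:
    reference_file = Path(tmp_dir) / "init_weights.pt"
    first_run = check_against_reference(initial_weights(123), reference_file)
    second_run = check_against_reference(initial_weights(123), reference_file)
    other_seed = check_against_reference(initial_weights(124), reference_file)
```

The first call creates the reference, the second call reproduces it, and the call with a different seed no longer matches the stored values.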
test_forward_pass_is_reproducible
test_forward_pass_is_reproducible(
forward_pass_input: Any,
algorithm: AlgorithmType,
seed: int,
tensor_regression: TensorRegressionFixture,
)
Check that the forward pass is reproducible given the same input and random seed.
test_backward_pass_is_reproducible
test_backward_pass_is_reproducible(
datamodule: LightningDataModule,
algorithm: AlgorithmType,
seed: int,
accelerator: str,
devices: int | list[int],
tensor_regression: TensorRegressionFixture,
tmp_path: Path,
)
Check that the backward pass is reproducible given the same weights, inputs and random seed.
forward_pass_input
Extracts the model input from a batch of data coming from the dataloader.
Override this if your batches are not tuples of tensors (i.e. if your algorithm isn't a simple supervised learning algorithm like the example).
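To make the distinction concrete, the sketch below shows input extraction for a tuple batch and for a dictionary batch. Both functions and the `"image"` key are hypothetical; they illustrate the logic an override might contain and do not show the actual signature of `forward_pass_input`.

```python
import torch


def input_from_tuple_batch(
    batch: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """Supervised-learning case: the batch is an (inputs, targets) tuple."""
    inputs, _targets = batch
    return inputs


def input_from_dict_batch(batch: dict[str, torch.Tensor]) -> torch.Tensor:
    """Non-tuple case: pick the model input out of a dictionary batch."""
    return batch["image"]  # hypothetical key


tuple_batch = (torch.zeros(8, 3, 32, 32), torch.zeros(8, dtype=torch.long))
dict_batch = {"image": torch.zeros(8, 3, 32, 32), "mask": torch.zeros(8, 32, 32)}

x_from_tuple = input_from_tuple_batch(tuple_batch)
x_from_dict = input_from_dict_batch(dict_batch)
```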
do_one_step_of_training
do_one_step_of_training(
algorithm: AlgorithmType,
datamodule: LightningDataModule,
accelerator: str,
devices: int | list[int] | Literal["auto"],
callbacks: list[Callback],
tmp_path: Path,
)
Performs one step of training.
Override this if you train your algorithm differently.