Lightning module tests
Suite of tests for a LightningModule.
See the project.algorithms.image_classifier_test module for an example of how to use this.
LightningModuleTests
Bases: Generic[LightningModuleType], ABC
Suite of generic tests for a LightningModule.
Simply inherit from this class and decorate the class with the appropriate markers to get a set of decent unit tests that should apply to almost any LightningModule.
See the project.algorithms.image_classifier_test module for an example.
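The inherit-and-reuse pattern can be sketched without pytest or Lightning. In this minimal, stdlib-only sketch, `ModuleTestsBase`, `make_module`, `Counter`, `TestCounter` and `TestCounterTwoSteps` are hypothetical stand-ins, not the suite's real API: the generic base class holds the test methods plus an overridable training-step hook (named after `do_one_step_of_training`, with a simplified signature), and each concrete subclass only says how to build its module. In the real suite, fixtures such as `algorithm` presumably play the role of `make_module`.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

ModuleT = TypeVar("ModuleT")


class ModuleTestsBase(Generic[ModuleT], ABC):
    """Hypothetical generic suite: subclasses only say how to build the module."""

    @abstractmethod
    def make_module(self, seed: int) -> ModuleT:
        """Build a fresh module (stands in for the `algorithm` fixture)."""

    def do_one_step_of_training(self, module: ModuleT) -> None:
        """Default training step; override it if your module trains differently."""
        module.step()  # assumption: modules expose a `step()` method

    def test_init_is_reproducible(self) -> None:
        assert self.make_module(seed=0).state() == self.make_module(seed=0).state()

    def test_update_is_reproducible(self) -> None:
        a, b = self.make_module(seed=1), self.make_module(seed=1)
        self.do_one_step_of_training(a)
        self.do_one_step_of_training(b)
        assert a.state() == b.state()


class Counter:
    """Toy 'module' whose whole state is one integer."""

    def __init__(self, seed: int) -> None:
        self.value = seed

    def step(self) -> None:
        self.value += 1

    def state(self) -> int:
        return self.value


class TestCounter(ModuleTestsBase[Counter]):
    """Concrete suite: inherits every test, only implements make_module."""

    def make_module(self, seed: int) -> Counter:
        return Counter(seed)


class TestCounterTwoSteps(TestCounter):
    """Same tests, but with an overridden training step."""

    def do_one_step_of_training(self, module: Counter) -> None:
        module.step()
        module.step()


# pytest would discover and run the test_* methods; here they are called by hand.
for suite in (TestCounter(), TestCounterTwoSteps()):
    suite.test_init_is_reproducible()
    suite.test_update_is_reproducible()
```

The design keeps the test logic in one place; a module with an unusual training loop only needs to replace the hook method.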
Other ideas:
- pytest-benchmark for regression tests on forward / backward pass / training step speed
- pytest-profiling for profiling the training step? (pytorch variant?)
- Dataset splits: check some basic stats about the train/val/test inputs, are they somewhat similar?
- Define the input as a space, check that the dataset samples are in that space and not too many samples are statistically OOD?
- Test to monitor distributed traffic out of this process?
- Dummy two-process tests (on CPU) to check before scaling up experiments?
config
config(dict_config: DictConfig) -> Config
The experiment configuration, with all interpolations resolved.
algorithm
algorithm(
config: Config,
datamodule: LightningDataModule | None,
trainer: Trainer,
device: device,
)
Fixture that creates the "algorithm" (usually a LightningModule).
make_torch_deterministic
Set torch to deterministic mode for unit tests that use the tensor_regression fixture.
seed
seed(request: FixtureRequest)
Fixture that seeds everything for reproducibility and yields the random seed used.
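A minimal sketch of a seeding fixture written as a plain generator, the shape pytest uses for fixtures with a teardown step. The name `seed_fixture` and the default seed are hypothetical; the sketch takes the seed as an argument instead of reading anything from `request`, and it only seeds Python's `random` module, whereas a fixture that "seeds everything" would presumably also seed NumPy and torch.

```python
import random
from collections.abc import Iterator


def seed_fixture(seed: int = 42) -> Iterator[int]:
    """Hypothetical seeding fixture; in a test module it would carry @pytest.fixture."""
    previous_state = random.getstate()  # save the global RNG state
    random.seed(seed)  # only `random` is seeded in this sketch
    try:
        yield seed  # the test body runs while the generator is suspended here
    finally:
        random.setstate(previous_state)  # teardown: restore the RNG for later tests


# Driving the generator by hand, roughly as pytest does:
fixture = seed_fixture(123)
used_seed = next(fixture)  # setup
first = [random.random() for _ in range(3)]
fixture.close()  # teardown: runs the `finally` block

fixture = seed_fixture(123)
next(fixture)
second = [random.random() for _ in range(3)]
fixture.close()

assert used_seed == 123 and first == second
```

Restoring the previous state keeps one test's seeding from silently changing the random numbers that another test sees.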
training_step_content
training_step_content(
datamodule: LightningDataModule | None,
algorithm: LightningModuleType,
seed: int,
accelerator: str,
devices: int | list[int],
tmp_path_factory: TempPathFactory,
)
Fixture that runs a training step and makes various things available for tests.
test_initialization_is_reproducible
test_initialization_is_reproducible(
training_step_content: StuffFromFirstTrainingStep,
tensor_regression: TensorRegressionFixture,
accelerator: str,
)
Check that the network initialization is reproducible given the same random seed.
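The `tensor_regression` fixture compares values against previously stored ones. The idea can be sketched with the standard library alone; this is *not* the `tensor_regression` API, and `init_weights`, `check_regression` and the JSON file layout are hypothetical. The first run writes a reference file, later runs must match it within a tolerance, so any change in how the seed maps to initial weights makes the check fail.

```python
import json
import math
import random
import tempfile
from pathlib import Path


def init_weights(seed: int, n: int = 4) -> list[float]:
    """Hypothetical network initialization driven only by the seed."""
    rng = random.Random(seed)  # local RNG: unaffected by other code using `random`
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def check_regression(name: str, values: list[float], data_dir: Path, rel_tol: float = 1e-6) -> None:
    """Write `values` on the first run; on later runs, compare against the stored copy."""
    path = data_dir / f"{name}.json"
    if not path.exists():
        path.write_text(json.dumps(values))  # floats round-trip exactly through json
        return
    expected = json.loads(path.read_text())
    assert len(expected) == len(values), f"{name}: length changed"
    for i, (exp, got) in enumerate(zip(expected, values)):
        assert math.isclose(exp, got, rel_tol=rel_tol), f"{name}[{i}]: {exp} != {got}"


with tempfile.TemporaryDirectory() as tmp:
    data_dir = Path(tmp)
    check_regression("init_seed0", init_weights(seed=0), data_dir)  # first run: writes the file
    check_regression("init_seed0", init_weights(seed=0), data_dir)  # same seed: passes
```

In a real project the reference files would presumably be committed, so later runs (or runs on other machines) are compared against the same stored values.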
test_forward_pass_is_reproducible
test_forward_pass_is_reproducible(
algorithm: LightningModuleType,
training_step_content: StuffFromFirstTrainingStep,
tensor_regression: TensorRegressionFixture,
)
Check that the forward pass is reproducible given the same input and random seed.
Note: There could be more than one call to forward inside a training step. Here we only check the args/kwargs/outputs of the first forward call for now.
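One way to capture what this note describes, sketched in plain Python: wrap `forward` so that every call records its args, kwargs and output, then inspect only the first record. `record_calls`, `TinyModel` and its `training_step` are hypothetical; the real `GetStuffFromFirstTrainingStep` callback may use a different mechanism to observe the forward calls.

```python
import functools
from typing import Any, Callable


def record_calls(fn: Callable[..., Any], log: list[tuple[tuple, dict, Any]]) -> Callable[..., Any]:
    """Wrap `fn` so that each call appends (args, kwargs, output) to `log`."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        out = fn(*args, **kwargs)
        log.append((args, kwargs, out))
        return out

    return wrapper


class TinyModel:
    """Hypothetical model whose training step calls forward twice."""

    def forward(self, x: float, scale: float = 1.0) -> float:
        return 2.0 * x * scale

    def training_step(self, batch: list[float]) -> float:
        first = self.forward(batch[0])
        second = self.forward(batch[1], scale=0.5)  # a second forward call
        return first + second


model = TinyModel()
calls: list[tuple[tuple, dict, Any]] = []
model.forward = record_calls(model.forward, calls)  # instance attribute shadows the method
model.training_step([1.0, 3.0])

args, kwargs, output = calls[0]  # as in the test: only the first forward call is checked
assert len(calls) == 2 and args == (1.0,) and kwargs == {} and output == 2.0
```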
test_backward_pass_is_reproducible
test_backward_pass_is_reproducible(
training_step_content: StuffFromFirstTrainingStep,
tensor_regression: TensorRegressionFixture,
accelerator: str,
)
Check that the backward pass is reproducible given the same weights, inputs and random seed.
test_update_is_reproducible
test_update_is_reproducible(
algorithm: LightningModuleType,
training_step_content: StuffFromFirstTrainingStep,
tensor_regression: TensorRegressionFixture,
accelerator: str,
)
Check that the weights after one step of training are the same given the same seed.
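Why a fixed seed pins down both the gradients (checked by the backward-pass test) and the updated weights (checked by the update test) can be seen in a dependency-free sketch: a linear model with a mean squared-error loss, one step of plain gradient descent, and every random draw, including the synthetic batch, taken from `random.Random(seed)`. `train_one_step`, the learning rate and the sizes are illustrative assumptions, not the suite's code.

```python
import random


def train_one_step(
    seed: int, lr: float = 0.1, dim: int = 3, batch_size: int = 4
) -> tuple[list[float], list[float]]:
    """One gradient-descent step of a linear model on a random batch; returns (grads, new_weights)."""
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # initialization
    xs = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(batch_size)]
    ys = [rng.gauss(0.0, 1.0) for _ in range(batch_size)]

    # Loss L(w) = mean_i (w.x_i - y_i)^2, so dL/dw_j = mean_i 2 (w.x_i - y_i) x_ij
    grads = [0.0] * dim
    for x, y in zip(xs, ys):
        err = sum(wj * xj for wj, xj in zip(w, x)) - y
        for j in range(dim):
            grads[j] += 2.0 * err * x[j] / batch_size

    new_w = [wj - lr * gj for wj, gj in zip(w, grads)]  # gradient-descent update
    return grads, new_w


grads_a, weights_a = train_one_step(seed=0)
grads_b, weights_b = train_one_step(seed=0)
assert grads_a == grads_b and weights_a == weights_b  # exact: same ops in the same order
```

Pure-Python float arithmetic on one machine is deterministic, so exact equality holds here. On GPUs, some torch kernels are not deterministic even with a fixed seed, which is presumably why the suite has a separate `make_torch_deterministic` fixture.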
do_one_step_of_training
do_one_step_of_training(
algorithm: LightningModuleType,
datamodule: LightningDataModule | None,
accelerator: str,
devices: int | list[int] | Literal["auto"],
callbacks: list[Callback],
tmp_path: Path,
)
Performs one step of training.
Override this method if your algorithm is trained differently.
StuffFromFirstTrainingStep (dataclass)
Dataclass that holds information gathered from a training step and used in tests.
batch
batch: Any | None = None
The input batch passed to the training_step method.
forward_args
The input args passed to each call to forward during the training step.
forward_kwargs
The input kwargs passed to each call to forward during the training step.
forward_outputs
The outputs of each call to the forward method during the training step.
initial_state_dict
A copy of the state dict before the training step (moved to CPU).
grads
A copy of the gradients of the model parameters after the backward pass (moved to CPU).
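A rough stdlib sketch of such a container. Only the `batch` field's type and default are given above; the class name `StuffFromFirstTrainingStepSketch`, the other field types and the empty-container defaults are assumptions. Mutable defaults such as lists and dicts have to go through `field(default_factory=...)`, because dataclasses reject a plain `[]` or `{}` default.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class StuffFromFirstTrainingStepSketch:
    """Hypothetical mirror of StuffFromFirstTrainingStep; field types are assumptions."""

    batch: Optional[Any] = None
    forward_args: list[tuple[Any, ...]] = field(default_factory=list)
    forward_kwargs: list[dict[str, Any]] = field(default_factory=list)
    forward_outputs: list[Any] = field(default_factory=list)
    initial_state_dict: dict[str, Any] = field(default_factory=dict)
    grads: dict[str, Any] = field(default_factory=dict)


stuff = StuffFromFirstTrainingStepSketch()
stuff.forward_args.append((1.0,))
state = {"w": [0.5, -0.2]}
stuff.initial_state_dict = copy.deepcopy(state)  # a copy, so later updates to `state` don't leak in
state["w"][0] = 9.0
assert stuff.initial_state_dict["w"][0] == 0.5
assert StuffFromFirstTrainingStepSketch().forward_args == []  # each instance gets its own list
```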
GetStuffFromFirstTrainingStep
Bases: Callback
Callback used in tests to get things from the first call to training_step.
convert_list_and_tuples_to_dicts
Converts all lists and tuples in a nested structure to dictionaries.
>>> convert_list_and_tuples_to_dicts([1, 2, 3])
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts((1, 2, 3))
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts({"a": [1, 2, 3], "b": (4, 5, 6)})
{'a': {'0': 1, '1': 2, '2': 3}, 'b': {'0': 4, '1': 5, '2': 6}}
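A recursive implementation consistent with these examples might look like the following. It is a sketch, not necessarily the project's code: it assumes dict keys are kept as-is, that every list or tuple (namedtuples included, since they are tuples) becomes a dict keyed by the string index, and that all other values are returned unchanged.

```python
from typing import Any


def convert_list_and_tuples_to_dicts(value: Any) -> Any:
    """Recursively replace lists and tuples with dicts keyed by '0', '1', ..."""
    if isinstance(value, dict):
        # Keep the keys, convert the values.
        return {k: convert_list_and_tuples_to_dicts(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        # Positions become string keys.
        return {str(i): convert_list_and_tuples_to_dicts(v) for i, v in enumerate(value)}
    return value  # leaves (numbers, tensors, ...) are returned as-is


print(convert_list_and_tuples_to_dicts({"a": [1, 2, 3], "b": (4, 5, 6)}))
# {'a': {'0': 1, '1': 2, '2': 3}, 'b': {'0': 4, '1': 5, '2': 6}}
```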