
Lightning module tests

Suite of tests for a LightningModule.

See the project.algorithms.image_classifier_test module for an example of how to use this.

LightningModuleTests #

Bases: Generic[LightningModuleType], ABC

Suite of generic tests for a LightningModule.

Simply inherit from this class and decorate the class with the appropriate markers to get a set of decent unit tests that should apply to almost any LightningModule.

See the project.algorithms.image_classifier_test module for an example.
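A minimal sketch of such a subclass, assuming a hypothetical `MyAlgorithm` LightningModule and an assumed import path; the markers to apply are the ones used in project.algorithms.image_classifier_test:

```python
# Sketch: MyAlgorithm and both import paths are illustrative assumptions.
from project.algorithms.my_algorithm import MyAlgorithm  # hypothetical module
from project.algorithms.testsuites.lightning_module_tests import LightningModuleTests


class TestMyAlgorithm(LightningModuleTests[MyAlgorithm]):
    """Inherits the whole suite of reproducibility tests below for MyAlgorithm."""
```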

Other ideas:

- pytest-benchmark for regression tests on forward / backward pass / training step speed
- pytest-profiling for profiling the training step? (pytorch variant?)
- Dataset splits: check some basic stats about the train/val/test inputs, are they somewhat similar?
- Define the input as a space, check that the dataset samples are in that space and not too many samples are statistically OOD?
- Test to monitor distributed traffic out of this process?
- Dummy two-process tests (on CPU) to check before scaling up experiments?

config #

config(dict_config: DictConfig) -> Config

The experiment configuration, with all interpolations resolved.
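Resolving interpolations on a DictConfig is typically done through OmegaConf; a minimal sketch of that step (the conversion into the project's Config object is elided):

```python
from omegaconf import DictConfig, OmegaConf


def resolve_interpolations(dict_config: DictConfig) -> DictConfig:
    # Resolve ${...} interpolations in place so that tests see concrete values
    # instead of references into other parts of the configuration.
    OmegaConf.resolve(dict_config)
    return dict_config
```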

algorithm #

algorithm(
    config: Config,
    datamodule: LightningDataModule | None,
    trainer: Trainer,
    device: device,
)

Fixture that creates the "algorithm" (usually a LightningModule).
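If the default way of creating the algorithm does not fit a particular module, this fixture can be overridden in the subclass; a sketch, where MyAlgorithm and its constructor arguments are assumptions:

```python
import pytest


class TestMyAlgorithm(LightningModuleTests[MyAlgorithm]):
    @pytest.fixture
    def algorithm(self, config, datamodule, trainer, device):
        # Build the LightningModule directly; the rest of the suite
        # (fixtures and tests) keeps working unchanged.
        return MyAlgorithm(datamodule=datamodule).to(device)
```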

make_torch_deterministic #

make_torch_deterministic()

Set torch to deterministic mode for unit tests that use the tensor_regression fixture.
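Deterministic mode here boils down to the standard PyTorch switches; conceptually, the fixture does something like this sketch:

```python
import pytest
import torch


@pytest.fixture
def make_torch_deterministic():
    # Error out on (or avoid) non-deterministic kernels so that tensor
    # regression files stay stable across runs, then restore the old setting.
    previous = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    yield
    torch.use_deterministic_algorithms(previous)
```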

seed #

seed(request: FixtureRequest)

Fixture that seeds everything for reproducibility and yields the random seed used.
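Conceptually this amounts to calling Lightning's seed_everything and yielding the seed; a sketch, where the default seed value and the use of an indirect parameter are assumptions:

```python
import lightning
import pytest


@pytest.fixture
def seed(request: pytest.FixtureRequest):
    # Use the seed passed via indirect parametrization if any, else a default,
    # and seed Python / NumPy / PyTorch through Lightning.
    random_seed = getattr(request, "param", 42)
    lightning.seed_everything(random_seed, workers=True)
    yield random_seed
```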

training_step_content #

training_step_content(
    datamodule: LightningDataModule | None,
    algorithm: LightningModuleType,
    seed: int,
    accelerator: str,
    devices: int | list[int],
    tmp_path_factory: TempPathFactory,
)

Fixture that runs a training step and makes various things available for tests.

test_initialization_is_reproducible #

test_initialization_is_reproducible(
    training_step_content: StuffFromFirstTrainingStep,
    tensor_regression: TensorRegressionFixture,
    accelerator: str,
)

Check that the network initialization is reproducible given the same random seed.
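In essence, the test saves the initial weights in a tensor regression file and compares against it on later runs; a simplified sketch (the exact check arguments are assumptions):

```python
def test_initialization_is_reproducible(training_step_content, tensor_regression, accelerator):
    # If the seed and the initialization code are unchanged, the initial
    # parameters must match the stored regression file exactly.
    tensor_regression.check(training_step_content.initial_state_dict)
```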

test_forward_pass_is_reproducible #

test_forward_pass_is_reproducible(
    algorithm: LightningModuleType,
    training_step_content: StuffFromFirstTrainingStep,
    tensor_regression: TensorRegressionFixture,
)

Check that the forward pass is reproducible given the same input and random seed.

Note: There could be more than one call to forward inside a training step. Here we only check the args/kwargs/outputs of the first forward call for now.

test_backward_pass_is_reproducible #

test_backward_pass_is_reproducible(
    training_step_content: StuffFromFirstTrainingStep,
    tensor_regression: TensorRegressionFixture,
    accelerator: str,
)

Check that the backward pass is reproducible given the same weights, inputs and random seed.

test_update_is_reproducible #

test_update_is_reproducible(
    algorithm: LightningModuleType,
    training_step_content: StuffFromFirstTrainingStep,
    tensor_regression: TensorRegressionFixture,
    accelerator: str,
)

Check that the weights after one step of training are the same given the same seed.

do_one_step_of_training #

do_one_step_of_training(
    algorithm: LightningModuleType,
    datamodule: LightningDataModule | None,
    accelerator: str,
    devices: int | list[int] | Literal["auto"],
    callbacks: list[Callback],
    tmp_path: Path,
)

Performs one step of training.

Overwrite this if you train your algorithm differently.
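A sketch of what an override could look like, using Trainer options that approximate "one step of training" (the exact defaults and the return value are assumptions):

```python
from lightning import Trainer


class TestMyAlgorithm(LightningModuleTests[MyAlgorithm]):
    def do_one_step_of_training(self, algorithm, datamodule, accelerator, devices, callbacks, tmp_path):
        # Run a single training batch with the callbacks that record the
        # forward/backward information used by the tests above.
        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            callbacks=callbacks,
            limit_train_batches=1,
            max_epochs=1,
            default_root_dir=tmp_path,
            logger=False,
            enable_checkpointing=False,
        )
        trainer.fit(algorithm, datamodule=datamodule)
        return callbacks  # assumption: mirror whatever the base implementation returns
```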

StuffFromFirstTrainingStep dataclass #

Dataclass that holds information gathered from a training step and used in tests.

batch class-attribute instance-attribute #

batch: Any | None = None

The input batch passed to the training_step method.

forward_args class-attribute instance-attribute #

forward_args: list[tuple[Any, ...]] = field(
    default_factory=list
)

The input args passed to each call to forward during the training step.

forward_kwargs class-attribute instance-attribute #

forward_kwargs: list[dict[str, Any]] = field(
    default_factory=list
)

The input kwargs passed to each call to forward during the training step.

forward_outputs class-attribute instance-attribute #

forward_outputs: list[Any] = field(default_factory=list)

The outputs of each call to the forward method during the training step.

initial_state_dict class-attribute instance-attribute #

initial_state_dict: dict[str, Tensor] = field(
    default_factory=dict
)

A copy of the state dict before the training step (moved to CPU).

grads class-attribute instance-attribute #

grads: dict[str, Tensor | None] = field(
    default_factory=dict
)

A copy of the gradients of the model parameters after the backward pass (moved to CPU).

training_step_output class-attribute instance-attribute #

training_step_output: Tensor | Mapping[str, Any] | None = (
    None
)

The output of the training_step method.
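Putting the attributes above together, the dataclass looks roughly like this (reassembled from the fields listed above):

```python
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

from torch import Tensor


@dataclass
class StuffFromFirstTrainingStep:
    batch: Any | None = None
    forward_args: list[tuple[Any, ...]] = field(default_factory=list)
    forward_kwargs: list[dict[str, Any]] = field(default_factory=list)
    forward_outputs: list[Any] = field(default_factory=list)
    initial_state_dict: dict[str, Tensor] = field(default_factory=dict)
    grads: dict[str, Tensor | None] = field(default_factory=dict)
    training_step_output: Tensor | Mapping[str, Any] | None = None
```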

GetStuffFromFirstTrainingStep #

Bases: Callback

Callback used in tests to get things from the first call to training_step.
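Conceptually, the callback fills a StuffFromFirstTrainingStep instance from the first training batch; a simplified sketch using standard Lightning hooks (the real callback also records forward args/kwargs/outputs and gradients):

```python
from lightning.pytorch import Callback


class GetStuffFromFirstTrainingStepSketch(Callback):
    def __init__(self):
        self.stuff = StuffFromFirstTrainingStep()

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if batch_idx == 0 and self.stuff.batch is None:
            # Keep the very first batch so tests can re-run forward/backward on it.
            self.stuff.batch = batch

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx == 0 and self.stuff.training_step_output is None:
            self.stuff.training_step_output = outputs
```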

convert_list_and_tuples_to_dicts #

convert_list_and_tuples_to_dicts(value: Tensor) -> Tensor
convert_list_and_tuples_to_dicts(
    value: dict | tuple | list,
) -> dict[str, Any]
convert_list_and_tuples_to_dicts(
    value: Tensor | dict | tuple | list,
) -> Tensor | dict[str, Any]

Converts all lists and tuples in a nested structure to dictionaries.

>>> convert_list_and_tuples_to_dicts([1, 2, 3])
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts((1, 2, 3))
{'0': 1, '1': 2, '2': 3}
>>> convert_list_and_tuples_to_dicts({"a": [1, 2, 3], "b": (4, 5, 6)})
{'a': {'0': 1, '1': 2, '2': 3}, 'b': {'0': 4, '1': 5, '2': 6}}
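A recursive sketch matching the behaviour shown above (the real function is overloaded; Tensors and other leaves pass through unchanged):

```python
from typing import Any

from torch import Tensor


def convert_list_and_tuples_to_dicts_sketch(
    value: Tensor | dict | tuple | list,
) -> Tensor | dict[str, Any]:
    # Lists/tuples become dicts keyed by the stringified index, dicts are
    # converted recursively, and anything else is returned as-is.
    if isinstance(value, (list, tuple)):
        return {str(i): convert_list_and_tuples_to_dicts_sketch(v) for i, v in enumerate(value)}
    if isinstance(value, dict):
        return {k: convert_list_and_tuples_to_dicts_sketch(v) for k, v in value.items()}
    return value
```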