Testutils

Utility functions useful for testing.

default_marks_for_config_name module-attribute #

default_marks_for_config_name: dict[
    str, list[MarkDecorator]
] = {
    "imagenet32": [slow],
    "inaturalist": [
        slow,
        skipif(
            not (NETWORK_DIR and NETWORK_DIR.exists()),
            reason="Expects to be run on the Mila cluster for now",
        ),
    ],
    "imagenet": [
        slow,
        skipif(
            not (NETWORK_DIR and NETWORK_DIR.exists()),
            reason="Expects to be run on a cluster with the ImageNet dataset.",
        ),
    ],
    "vision": [
        skip(
            reason="Base class, shouldn't be instantiated."
        )
    ],
}

Dictionary with default pytest marks for some of the config names.
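
One way these marks can be consumed (the import path and config names below are assumptions for illustration) is to attach them to pytest.param entries when parametrizing a test by config name:

import pytest

# Hypothetical import path; adjust to wherever this testutils module lives in your project.
from project.testutils import default_marks_for_config_name

# Example config names; real names come from your config groups.
config_names = ["imagenet32", "cifar10"]

# Attach the default marks (slow, skipif, ...) to each parametrized case.
config_params = [
    pytest.param(name, marks=default_marks_for_config_name.get(name, []))
    for name in config_names
]


@pytest.mark.parametrize("config_name", config_params)
def test_config_loads(config_name: str):
    """Runs once per config name, with that config's default marks applied."""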

ParametrizedFixture #

Bases: Generic[T]

Small helper class that creates a parametrized pytest fixture for the given values.

The name of the fixture will be the name of the class attribute that this object is assigned to.

For example:

class TestFoo:
    odd = ParametrizedFixture([True, False])

    def test_something(self, odd: bool):
        '''some kind of test that uses odd'''

    # NOTE: This fixture can also be used by other fixtures:

    @pytest.fixture
    def some_number(self, odd: bool):
        return 1 if odd else 2

    def test_foo(self, some_number: int):
        '''some kind of test that uses some_number'''

parametrized_fixture #

parametrized_fixture(
    name: str, values: Sequence, ids=None, **kwargs
)

Small helper function that creates a parametrized pytest fixture for the given values.

NOTE: When writing a fixture in a test class, use ParametrizedFixture instead.
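
A minimal module-level usage sketch (the import path is an assumption):

# Hypothetical import path; adjust to your project's layout.
from project.testutils import parametrized_fixture

# Creates a module-level fixture named "seed", parametrized over these values.
seed = parametrized_fixture("seed", values=[42, 123])


def test_uses_seed(seed: int):
    """Runs once per seed value."""
    assert seed in (42, 123)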

run_for_all_datamodules #

run_for_all_datamodules(
    datamodule_names: list[str] | None = None,
    datamodule_name_to_marks: (
        dict[str, MarkDecorator | list[MarkDecorator]]
        | None
    ) = None,
)

Apply this marker to a test to make it run with all available datasets (datamodules).

The test should use the datamodule fixture, either as an input argument to the test function or indirectly by using a fixture that depends on the datamodule fixture.

Parameters#

datamodule_names: List of datamodule names to use for tests. By default, lists out the generic datamodules (the datamodules that aren't specific to a single algorithm, for example the InfGen datamodules of WakeSleep).

datamodule_name_to_marks: Dictionary from datamodule names to pytest marks (e.g. pytest.mark.xfail, pytest.mark.skip) to use for that particular datamodule.
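
A usage sketch (the import path, datamodule names and marks are illustrative assumptions):

import pytest

# Hypothetical import path; adjust to your project's layout.
from project.testutils import run_for_all_datamodules


@run_for_all_datamodules(
    datamodule_name_to_marks={
        "mnist": [],                     # hypothetical datamodule name
        "imagenet": [pytest.mark.slow],  # mark heavier datamodules as slow
    }
)
def test_datamodule_has_train_dataloader(datamodule):
    """Runs once per selected datamodule, through the `datamodule` fixture."""
    assert hasattr(datamodule, "train_dataloader")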

run_for_all_configs_of_type #

run_for_all_configs_of_type(
    config_group: str,
    target_type: type,
    excluding: type | tuple[type, ...] = (),
)

Parametrizes a test to run with all the configs in the given group that have targets which are subclasses of the given type.

For example:

@run_for_all_configs_of_type("algorithm", torch.nn.Module)
def test_something_about_the_algorithm(algorithm: torch.nn.Module):
    ''' This test will run with all the configs in the 'algorithm' group that create nn.Modules! '''

Concretely, this works by indirectly parametrizing the f"{config_group}_config" fixture. To learn more about indirect parametrization in PyTest, take a look at https://docs.pytest.org/en/stable/example/parametrize.html#indirect-parametrization
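
For the example above, the effect is roughly equivalent to the hand-written indirect parametrization sketched below; the config names are placeholders, since the decorator discovers the real ones by inspecting each config's target type:

import pytest


@pytest.mark.parametrize(
    "algorithm_config",
    ["example_algo", "another_algo"],  # placeholder config names
    indirect=True,  # values are handed to the `algorithm_config` fixture
)
def test_something_about_the_algorithm(algorithm):
    """`algorithm` is assumed to be a fixture that builds the module from `algorithm_config`."""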

parametrize_when_used #

parametrize_when_used(
    arg_name_or_fixture: str | Callable,
    values: list,
    indirect: bool | None = None,
) -> MarkDecorator

Returns a mark that applies pytest.mark.parametrize only when the argument is actually used (directly or indirectly) by the test.

When pytest.mark.parametrize is applied to a class, all test methods in that class need to use the parametrized argument, otherwise an error is raised. This function exists to work around this and allows writing test methods that don't use the parametrized argument.

For example, this works, but would not be possible with pytest.mark.parametrize:

import pytest

@parametrize_when_used("value", [1, 2, 3])
class TestFoo:
    def test_foo(self, value):
        ...

    def test_bar(self, value):
        ...

    def test_something_else(self):  # Doesn't use `value`: allowed here, but an error with plain pytest.mark.parametrize
        pass

Parameters#

arg_name_or_fixture: The name of the argument to parametrize, or a fixture to parametrize indirectly.

values: The values to be used to parametrize the test.

Returns#

A pytest.MarkDecorator that parametrizes the test with the given values only when the argument is used (directly or indirectly) by the test.
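
As noted above, a fixture can be passed instead of an argument name to parametrize it indirectly. A small sketch (the import path and the fixture are made up for illustration):

import pytest

# Hypothetical import path; adjust to your project's layout.
from project.testutils import parametrize_when_used


@pytest.fixture
def batch_size(request) -> int:
    # With indirect parametrization, the parametrized value shows up on `request.param`.
    return getattr(request, "param", 32)


@parametrize_when_used(batch_size, [16, 64])
def test_uses_batch_size(batch_size: int):
    assert batch_size in (16, 64)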

run_for_all_configs_in_group #

run_for_all_configs_in_group(
    group_name: str,
    config_name_to_marks: (
        Mapping[str, MarkDecorator | list[MarkDecorator]]
        | None
    ) = None,
)

Apply this marker to a test to make it run with all configs in a given group.

This assumes that a fixture named f"{group_name}_config" is defined, for example algorithm_config, datamodule_config, or network_config. That fixture is then parametrized indirectly, so that it receives the config name as a parameter and returns it.

The wrapped test will run with every config from that group, provided the config fixture is used either as an input argument to the test function or as an input argument to another fixture that the test uses.

Parameters#

group_name: The name of the config group to parametrize over (for example "algorithm", "datamodule" or "network").

config_name_to_marks: Dictionary from config names to pytest marks (e.g. pytest.mark.xfail, pytest.mark.skip) to use for that particular config.
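
A usage sketch, assuming an algorithm_config fixture exists (the import path and the marked config name are illustrative):

import pytest

# Hypothetical import path; adjust to your project's layout.
from project.testutils import run_for_all_configs_in_group


@run_for_all_configs_in_group(
    "algorithm",
    config_name_to_marks={"some_heavy_algo": pytest.mark.slow},  # placeholder config name
)
def test_algorithm_config_name(algorithm_config: str):
    """Runs once per config in the "algorithm" group; the fixture returns the config name."""
    assert isinstance(algorithm_config, str)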