Testutils

Utility functions useful for testing.

default_marks_for_config_name module-attribute

default_marks_for_config_name: dict[
    str, list[MarkDecorator]
] = {
    "inaturalist": [
        slow,
        skipif(
            not NETWORK_DIR and exists(),
            reason="Expects to be run on the Mila cluster for now",
        ),
    ],
    "imagenet": [
        slow,
        skipif(
            not NETWORK_DIR and exists(),
            reason="Expects to be run on a cluster with the ImageNet dataset.",
        ),
    ],
    "vision": [
        skip(
            reason="Base class, shouldn't be instantiated."
        )
    ],
}

Dict mapping config names to the default pytest marks applied to tests that use those configs.

default_marks_for_config_combinations module-attribute

default_marks_for_config_combinations: dict[
    tuple[str, ...], list[MarkDecorator]
] = {
    ("imagenet", "fcnet"): [
        xfail(
            reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters."
        )
    ],
    ("imagenet", "jax_fcnet"): [
        xfail(
            reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters."
        )
    ],
    ("imagenet", "jax_cnn"): [
        xfail(
            reason="todo: parameters contain nans when overfitting on one batch? Maybe we're using too many iterations?"
        )
    ],
    **{
        (resnet_config, mnist_dataset_config): [
            skip(
                reason="ResNets don't work with MNIST datasets because the image resolution is too small."
            )
        ]
        for (
            resnet_config,
            mnist_dataset_config,
        ) in product(
            get_all_configs_in_group_of_type(
                "algorithm/network", ResNet
            ),
            get_all_configs_in_group_of_type(
                "datamodule",
                (MNISTDataModule, FashionMNISTDataModule),
            ),
        )
    },
}

Dict mapping combinations of config names to the marks added to a test when all of the configs in the combination are used together.

For example, ResNet networks can't be applied to the MNIST datasets.
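The two dictionaries above could be combined into the marks for a single test roughly as follows. This is a minimal sketch, assuming the test knows the set of config names it uses; `collect_marks` is a hypothetical helper, not part of this module, and plain strings stand in for `pytest.MarkDecorator` objects.

```python
from __future__ import annotations

from collections.abc import Iterable, Mapping
from typing import TypeVar

M = TypeVar("M")  # stands in for pytest.MarkDecorator


def collect_marks(
    config_names: Iterable[str],
    marks_for_name: Mapping[str, list[M]],
    marks_for_combination: Mapping[tuple[str, ...], list[M]],
) -> list[M]:
    """Hypothetical helper: gather the default marks that apply to one test."""
    names = set(config_names)
    marks: list[M] = []
    # Marks attached to individual config names.
    for name in sorted(names):
        marks.extend(marks_for_name.get(name, []))
    # Marks attached to a combination apply only if every name in it is used.
    for combination, combination_marks in marks_for_combination.items():
        if names.issuperset(combination):
            marks.extend(combination_marks)
    return marks


marks = collect_marks(
    ["imagenet", "fcnet"],
    {"imagenet": ["slow", "skipif(no ImageNet)"]},
    {("imagenet", "fcnet"): ["xfail(nans)"], ("imagenet", "jax_cnn"): ["xfail"]},
)
print(marks)  # ['slow', 'skipif(no ImageNet)', 'xfail(nans)']
```

Checking combinations with a subset test means a combination mark also applies when the test uses additional, unrelated configs.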

get_target_of_config

get_target_of_config(
    config_group: str,
    config_name: str,
    _cs: ConfigStore | None = None,
) -> Callable

Returns the class that is to be instantiated by the given config name.

In the case of inner dataclasses (e.g. Model.HParams), this returns the outer class (Model).
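The function itself looks the config up in Hydra's ConfigStore. The sketch below only illustrates the "outer class" behaviour on a dotted `_target_` string, using `importlib` from the standard library; `resolve_outer_target` is a hypothetical name, and this is not necessarily how the module does it.

```python
import importlib
from typing import Any


def resolve_outer_target(target: str) -> Any:
    """Hypothetical helper: resolve a dotted `_target_` path to its outermost object.

    For "pkg.module.Model.HParams" this imports `pkg.module` and returns `Model`.
    """
    parts = target.split(".")
    # Try the longest importable module prefix first, then shorter ones.
    for i in range(len(parts) - 1, 0, -1):
        module_name = ".".join(parts[:i])
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        # The first name after the module path is the outer (top-level) object.
        return getattr(module, parts[i])
    raise ImportError(f"No importable module prefix in {target!r}")


# `failureException` is a class attribute of TestCase, standing in for a nested class.
print(resolve_outer_target("unittest.TestCase.failureException"))  # <class 'unittest.case.TestCase'>
```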

get_all_configs_in_group_of_type

get_all_configs_in_group_of_type(
    config_group: str,
    config_target_type: type | tuple[type, ...],
    include_subclasses: bool = True,
    excluding: type | tuple[type, ...] = (),
) -> list[str]

Returns the names of all the configs in the given config group that have this target or a subclass of it.
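The type filter can be sketched on a plain `{config name: target}` mapping. This assumes that `include_subclasses=False` means an exact type match and that `excluding` drops targets that are subclasses of the excluded types; `configs_of_type` and the toy classes are hypothetical, not this module's code.

```python
from __future__ import annotations

from collections.abc import Mapping


def configs_of_type(
    targets: Mapping[str, object],
    config_target_type: type | tuple[type, ...],
    include_subclasses: bool = True,
    excluding: type | tuple[type, ...] = (),
) -> list[str]:
    """Hypothetical helper: filter a {config name: target} mapping by target type."""
    wanted = config_target_type if isinstance(config_target_type, tuple) else (config_target_type,)
    excluded = excluding if isinstance(excluding, tuple) else (excluding,)
    names = []
    for name, target in targets.items():
        # Targets can also be plain functions; only classes are compared by type.
        if not isinstance(target, type):
            continue
        matches = issubclass(target, wanted) if include_subclasses else target in wanted
        # issubclass(target, ()) is False, so an empty `excluding` excludes nothing.
        if matches and not issubclass(target, excluded):
            names.append(name)
    return names


# Toy stand-ins for config targets.
class Base: ...
class Child(Base): ...
class GrandChild(Child): ...
class Unrelated: ...


targets = {"base": Base, "child": Child, "grandchild": GrandChild, "unrelated": Unrelated, "factory": len}
print(configs_of_type(targets, Base))                          # ['base', 'child', 'grandchild']
print(configs_of_type(targets, Base, excluding=GrandChild))    # ['base', 'child']
print(configs_of_type(targets, Base, include_subclasses=False))  # ['base']
```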

run_for_all_configs_of_type

run_for_all_configs_of_type(
    config_group: str,
    target_type: type,
    excluding: type | tuple[type, ...] = (),
)

Parametrizes a test to run with all the configs in the given group that have targets which are subclasses of the given type.

For example:

import torch

@run_for_all_configs_of_type("algorithm", torch.nn.Module)
def test_something_about_the_algorithm(algorithm: torch.nn.Module):
    """This test runs with every config in the 'algorithm' group whose target is an nn.Module."""

Concretely, this works by indirectly parametrizing the f"{config_group}_config" fixture. To learn more about indirect parametrization in PyTest, take a look at https://docs.pytest.org/en/stable/example/parametrize.html#indirect-parametrization
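In plain pytest, indirect parametrization of such a fixture looks roughly like the sketch below. The fixture body and the config names are assumptions for illustration; in this project the config fixture is defined by the test suite, and the names are discovered by run_for_all_configs_of_type rather than written by hand.

```python
import pytest


@pytest.fixture
def algorithm_config(request: pytest.FixtureRequest) -> str:
    # With indirect=True, each parametrized value arrives here as `request.param`.
    return request.param


@pytest.mark.parametrize("algorithm_config", ["example_a", "example_b"], indirect=True)
def test_uses_algorithm_config(algorithm_config: str):
    # Runs once per config name; each value passes through the fixture first.
    assert algorithm_config in {"example_a", "example_b"}
```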

parametrize_when_used

parametrize_when_used(
    arg_name_or_fixture: str | Callable,
    values: list,
    indirect: bool | None = None,
) -> MarkDecorator

Returns a mark that applies pytest.mark.parametrize only to tests that use the argument, either directly or indirectly (through a fixture).

When pytest.mark.parametrize is applied to a class, every test method in that class must use the parametrized argument; otherwise pytest raises an error. This function works around that restriction, so the class can also contain test methods that don't use the parametrized argument.

For example, this works, but would not be possible with pytest.mark.parametrize:

import pytest

@parametrize_when_used("value", [1, 2, 3])
class TestFoo:
    def test_foo(self, value):
        ...

    def test_bar(self, value):
        ...

    def test_something_else(self):  # Doesn't use `value`: fine here, an error with pytest.mark.parametrize.
        pass

Parameters:

- arg_name_or_fixture (str | Callable, required): The name of the argument to parametrize, or a fixture to parametrize indirectly.
- values (list, required): The values to be used to parametrize the test.

Returns:

- MarkDecorator: A pytest.MarkDecorator that parametrizes the test with the given values only when the argument is used (directly or indirectly) by the test.

run_for_all_configs_in_group

run_for_all_configs_in_group(
    group_name: str,
    config_name_to_marks: (
        Mapping[str, MarkDecorator | list[MarkDecorator]]
        | None
    ) = None,
)

Apply this marker to a test to make it run with all configs in a given group.

This assumes that a fixture named f"{group_name}_config" is defined, for example algorithm_config, datamodule_config, or network_config. The marker indirectly parametrizes that fixture, so that the fixture receives each config name as a parameter and returns it.

The wrapped test runs with every config from that group, provided that the config fixture is used, either as an argument of the test function or as an argument of a fixture that the test uses.

Parameters:

- group_name (str, required): The name of the config group whose configs the test should run with, for example "algorithm", "datamodule", or "network".
- config_name_to_marks (Mapping[str, MarkDecorator | list[MarkDecorator]] | None, default None): Dictionary from config names to pytest marks (e.g. pytest.mark.xfail, pytest.mark.skip) to apply when the test runs with that particular config.

total_vram_gb

total_vram_gb() -> float

Returns the total VRAM in GB.
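A rough equivalent can be sketched by querying `nvidia-smi`. This is a swapped-in technique, not this module's implementation: it sums `memory.total` over all visible NVIDIA GPUs (the original may report a single device), treats 1 GB as 1024 MiB, and returns 0.0 when `nvidia-smi` is missing, which is an assumption. The function names are hypothetical.

```python
import shutil
import subprocess


def parse_total_memory_mib(nvidia_smi_output: str) -> float:
    """Sum the per-GPU `memory.total` values (in MiB), one per output line."""
    return sum((float(line) for line in nvidia_smi_output.splitlines() if line.strip()), 0.0)


def total_vram_gb_sketch() -> float:
    """Hypothetical: total VRAM over all visible NVIDIA GPUs, in GB (1 GB = 1024 MiB)."""
    if shutil.which("nvidia-smi") is None:
        return 0.0  # assumption: no NVIDIA tooling means no VRAM to report
    output = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_total_memory_mib(output) / 1024
```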