Testutils
Utility functions useful for testing.
default_marks_for_config_name (module attribute)
default_marks_for_config_name: dict[str, list[MarkDecorator]] = {
    "imagenet32": [slow],
    "inaturalist": [
        slow,
        skipif(
            not NETWORK_DIR and exists(),
            reason="Expects to be run on the Mila cluster for now",
        ),
    ],
    "imagenet": [
        slow,
        skipif(
            not NETWORK_DIR and exists(),
            reason="Expects to be run on a cluster with the ImageNet dataset.",
        ),
    ],
    "vision": [
        skip(reason="Base class, shouldn't be instantiated."),
    ],
}
Dict with default marks for some config names.
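A minimal sketch of how a name-to-marks dict like this might be consumed when building test parameters, assuming the marks are attached with `pytest.param`. The dict below is a simplified, hypothetical stand-in (it leaves out the `skipif` dataset checks), and `params_for_configs` is a hypothetical helper, not part of this module.

```python
import pytest

# Simplified, hypothetical stand-in for default_marks_for_config_name
# (the real dict also contains skipif marks that check for datasets on disk).
marks_for_name = {
    "imagenet32": [pytest.mark.slow],
    "vision": [pytest.mark.skip(reason="Base class, shouldn't be instantiated.")],
}


def params_for_configs(config_names):
    """Hypothetical helper: wrap each config name in a pytest.param carrying its default marks."""
    return [
        # Config names without an entry get no marks.
        pytest.param(name, marks=marks_for_name.get(name, []), id=name)
        for name in config_names
    ]


params = params_for_configs(["imagenet32", "mnist", "vision"])
# `params` could then be passed to
# pytest.mark.parametrize("datamodule_config", params, indirect=True).
```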
default_marks_for_config_combinations (module attribute)
default_marks_for_config_combinations: dict[tuple[str, ...], list[MarkDecorator]] = {
    ("imagenet", "fcnet"): [
        xfail(
            reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters."
        )
    ],
    ("imagenet", "jax_fcnet"): [
        xfail(
            reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters."
        )
    ],
    ("imagenet", "jax_cnn"): [
        xfail(
            reason="todo: parameters contain nans when overfitting on one batch? Maybe we're using too many iterations?"
        )
    ],
    **{
        (resnet_config, mnist_dataset_config): [
            skip(
                reason="ResNets don't work with MNIST datasets because the image resolution is too small."
            )
        ]
        for resnet_config, mnist_dataset_config in product(
            get_all_configs_in_group_of_type("algorithm/network", ResNet),
            get_all_configs_in_group_of_type(
                "datamodule",
                (MNISTDataModule, FashionMNISTDataModule),
            ),
        )
    },
}
Dict with default marks to add to tests when certain config combinations are present.
For example, ResNet networks can't be applied to the MNIST datasets.
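To make the combination rule concrete, here is a pure-Python sketch of how marks keyed by tuples of config names might be selected for a test. It assumes that a key applies whenever every name in the tuple is among the configs the test uses. The string "marks" and the `marks_for_combination` helper are hypothetical placeholders. The sketch also uses the same `**{...}` dict-comprehension unpacking as above to generate many keys at once.

```python
from itertools import product

# Hypothetical placeholder marks (plain strings instead of pytest MarkDecorators).
combination_marks: dict[tuple[str, ...], list[str]] = {
    ("imagenet", "fcnet"): ["xfail: FcNet on ImageNet"],
    # Unpack a generated dict into this one, like the ResNet x MNIST entries above.
    **{
        (net, dm): ["skip: resolution too small"]
        for net, dm in product(["resnet18", "resnet50"], ["mnist", "fashion_mnist"])
    },
}


def marks_for_combination(used_configs: set[str]) -> list[str]:
    """Collect the marks of every key whose config names are all used by the test."""
    marks: list[str] = []
    for combination, extra in combination_marks.items():
        if set(combination) <= used_configs:
            marks.extend(extra)
    return marks
```

For instance, a test using the `resnet50` network with the `mnist` datamodule would pick up the skip mark, while `resnet50` with `cifar10` would get none.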
run_for_all_configs_of_type
run_for_all_configs_of_type(
    config_group: str,
    target_type: type,
    excluding: type | tuple[type, ...] = (),
)
Parametrizes a test to run with all the configs in the given group that have targets which are subclasses of the given type.
For example:
import torch

@run_for_all_configs_of_type("algorithm", torch.nn.Module)
def test_something_about_the_algorithm(algorithm: torch.nn.Module):
    '''This test will run with all the configs in the 'algorithm' group that create nn.Modules!'''
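The selection step can be sketched in plain Python, assuming each config in a group can be resolved to the class it instantiates. The classes, the `group_targets` mapping and the `configs_of_type` helper are all hypothetical. The actual function discovers the configs itself, and the dict here only stands in for that lookup.

```python
from typing import Union


# Hypothetical classes standing in for the targets of configs in one group.
class Network: ...
class ResNet(Network): ...
class ResNet18(ResNet): ...
class FcNet(Network): ...
class DataModule: ...


# Hypothetical stand-in for config discovery: config name -> target class.
group_targets: dict[str, type] = {
    "resnet18": ResNet18,
    "fcnet": FcNet,
    "mnist": DataModule,
}


def configs_of_type(
    targets: dict[str, type],
    target_type: type,
    excluding: Union[type, tuple[type, ...]] = (),
) -> list[str]:
    """Return the config names whose target subclasses target_type and none of `excluding`."""
    return [
        name
        for name, target in targets.items()
        # issubclass(x, ()) is False, so an empty `excluding` excludes nothing.
        if issubclass(target, target_type) and not issubclass(target, excluding)
    ]
```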
Concretely, this works by indirectly parametrizing the f"{config_group}_config" fixture.
To learn more about indirect parametrization in pytest, take a look at https://docs.pytest.org/en/stable/example/parametrize.html#indirect-parametrization
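For reference, this is what plain indirect parametrization looks like in pytest. With `indirect=True`, each parametrize value is routed to the fixture of the same name through `request.param`, and the test receives whatever the fixture returns. The config names in `ALGORITHM_CONFIGS` are hypothetical. This is the generic pytest pattern, not the internals of `run_for_all_configs_of_type`.

```python
import pytest

# Hypothetical config names; the real helper collects these from the config group.
ALGORITHM_CONFIGS = ["example_algo_a", "example_algo_b"]


@pytest.fixture
def algorithm_config(request: pytest.FixtureRequest) -> str:
    # With indirect parametrization, each value arrives here as request.param
    # instead of being passed straight to the test function.
    return request.param


@pytest.mark.parametrize("algorithm_config", ALGORITHM_CONFIGS, indirect=True)
def test_uses_algorithm_config(algorithm_config: str):
    assert algorithm_config in ALGORITHM_CONFIGS
```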
parametrize_when_used
parametrize_when_used(
    arg_name_or_fixture: str | Callable,
    values: list,
    indirect: bool | None = None,
) -> MarkDecorator
Returns a marker that applies pytest.mark.parametrize only when the argument is used (directly or indirectly).
When pytest.mark.parametrize is applied to a class, all test methods in that class need to use the parametrized argument, otherwise an error is raised. This function works around this limitation and allows writing test methods that don't use the parametrized argument.
For example, this works, but would not be possible with pytest.mark.parametrize:
import pytest

@parametrize_when_used("value", [1, 2, 3])
class TestFoo:
    def test_foo(self, value):
        ...

    def test_bar(self, value):
        ...

    def test_something_else(self):  # Would be an error with pytest.mark.parametrize, fine here.
        pass
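One simple way to decide whether a test "uses" an argument is to inspect its signature. The sketch below covers only direct usage. Detecting indirect usage through fixtures needs pytest's collection machinery, for example `metafunc.fixturenames` in a `pytest_generate_tests` hook. Nothing here claims this is how `parametrize_when_used` is implemented, and `uses_argument` and `ExampleSuite` are hypothetical names.

```python
import inspect
from typing import Callable


def uses_argument(test_function: Callable, arg_name: str) -> bool:
    """Return True if arg_name appears in the function's signature (direct use only)."""
    return arg_name in inspect.signature(test_function).parameters


class ExampleSuite:  # Not named Test* so pytest won't try to collect it.
    def test_foo(self, value): ...

    def test_bar(self, value): ...

    def test_something_else(self): ...


# Only the methods that take `value` would receive the parametrization.
methods_to_parametrize = [
    name
    for name, member in vars(ExampleSuite).items()
    if name.startswith("test_") and uses_argument(member, "value")
]
```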
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arg_name_or_fixture | str \| Callable | The name of the argument to parametrize, or a fixture to parametrize indirectly. | required |
values | list | The values to be used to parametrize the test. | required |
Returns:
Type | Description |
---|---|
MarkDecorator | A |
run_for_all_configs_in_group
run_for_all_configs_in_group(
    group_name: str,
    config_name_to_marks: (
        Mapping[str, MarkDecorator | list[MarkDecorator]]
        | None
    ) = None,
)
Apply this marker to a test to make it run with all configs in a given group.
This assumes that a f"{group_name}_config" fixture is defined, for example algorithm_config, datamodule_config, or network_config. This marker then indirectly parametrizes that fixture, so that it receives the config name as a parameter and returns it.
The wrapped test runs with all the configs from that group, whether the fixture is used directly as an input argument of the test function or as an input argument of a fixture that the test depends on.
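A minimal sketch of the mechanism just described, assuming the config names are already known. `run_for_configs_sketch` is a hypothetical stand-in that builds one `pytest.param` per config name, attaches any marks from `config_name_to_marks`, and indirectly parametrizes the `f"{group_name}_config"` fixture. The real helper finds the configs in the group by itself, and that discovery step is not shown.

```python
from typing import Mapping, Optional, Sequence

import pytest


def run_for_configs_sketch(
    group_name: str,
    config_names: Sequence[str],
    config_name_to_marks: Optional[Mapping[str, object]] = None,
):
    """Hypothetical sketch: indirectly parametrize the f"{group_name}_config" fixture.

    Unlike the real helper, the config names are passed in explicitly here.
    """
    config_name_to_marks = config_name_to_marks or {}
    params = []
    for name in config_names:
        marks = config_name_to_marks.get(name, [])
        # Accept either a single mark or a list of marks, like the documented parameter type.
        if not isinstance(marks, list):
            marks = [marks]
        params.append(pytest.param(name, marks=marks, id=name))
    return pytest.mark.parametrize(f"{group_name}_config", params, indirect=True)


marker = run_for_configs_sketch(
    "datamodule",
    ["mnist", "imagenet"],
    config_name_to_marks={"imagenet": pytest.mark.slow},
)
```

Decorating a test with `marker` would run it once per config name through the `datamodule_config` fixture, with the `slow` mark attached only to the `imagenet` case.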
Parameters:
Name | Type | Description | Default |
---|---|---|---|
group_name | str | Name of the config group whose configs the test should run with, for example "algorithm" or "datamodule". | required |
config_name_to_marks | Mapping[str, MarkDecorator \| list[MarkDecorator]] \| None | Dictionary from config names to pytest marks (e.g. | None |