Datamodules

Datamodules (datasets + preprocessing + dataloading)

See the lightning.LightningDataModule class for more information.

ImageClassificationDataModule #

Bases: VisionDataModule[BatchType], ClassificationDataModule[BatchType]

Lightning data modules for image classification.

num_classes instance-attribute #

num_classes: int

Number of classes in the dataset.

dims instance-attribute #

dims: tuple[C, H, W]

A tuple describing the shape of the data.
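
As a rough illustration of how `dims` and `num_classes` might be consumed together, the sketch below derives the input and output sizes of a flat linear classifier. It assumes `C`, `H` and `W` are thin `NewType` wrappers around `int` (the real definitions may differ), and `layer_sizes` is a hypothetical helper, not part of this package.

```python
from typing import NewType

# Assumption: C, H and W are thin int wrappers used only for readability.
C = NewType("C", int)
H = NewType("H", int)
W = NewType("W", int)


def layer_sizes(dims: tuple[C, H, W], num_classes: int) -> tuple[int, int]:
    """Return (input features, output logits) for a flat linear classifier."""
    channels, height, width = dims
    return channels * height * width, num_classes


# CIFAR-10-shaped data: 3 x 32 x 32 images, 10 classes.
in_features, out_features = layer_sizes((C(3), H(32), W(32)), num_classes=10)
print(in_features, out_features)
```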

CIFAR10DataModule #

Bases: ImageClassificationDataModule, VisionDataModule

.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
    :width: 400
    :alt: CIFAR-10

Specs
  • 10 classes
  • Each image is (3 x 32 x 32)

Standard CIFAR10 train, val, and test splits and transforms.

Transforms::

transforms = transform_lib.Compose([
    transform_lib.ToImage(),
    transform_lib.ToDtype(torch.float32, scale=True),
    transform_lib.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
    )
])
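
The division by 255 in the `Normalize` arguments can be checked by hand. The sketch below is plain Python, not the transform pipeline itself: it rescales the per-channel statistics to the [0, 1] range produced by `ToDtype(torch.float32, scale=True)` and applies `(x - mean) / std` to a single pixel. `normalize_pixel` is a hypothetical helper used only for illustration.

```python
# Per-channel statistics on the 0-255 scale, copied from the transform above.
MEAN_255 = [125.3, 123.0, 113.9]
STD_255 = [63.0, 62.1, 66.7]

# Pixel values are scaled to [0, 1] before normalization, so the statistics
# are divided by 255 to match that scale.
mean = [m / 255.0 for m in MEAN_255]
std = [s / 255.0 for s in STD_255]


def normalize_pixel(rgb_255: list[int]) -> list[float]:
    """Apply (x - mean) / std per channel to one RGB pixel given in 0-255."""
    scaled = [v / 255.0 for v in rgb_255]
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]


print([round(m, 4) for m in mean])
print(normalize_pixel([0, 0, 0]))
```

Because both the pixel and the statistics are divided by 255, the result equals `(x_255 - mean_255) / std_255` computed on the original scale.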

Example::

from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule(PATH)  # PATH: directory where the dataset is stored
model = LitModel()  # LitModel: your own LightningModule

Trainer().fit(model, datamodule=dm)

Or you can set your own transforms

Example::

dm.train_transforms = ...
dm.test_transforms = ...
dm.val_transforms  = ...

FashionMNISTDataModule #

Bases: MNISTDataModule

.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png
    :width: 400
    :alt: Fashion MNIST

Specs
  • 10 classes (1 per type)
  • Each image is (1 x 28 x 28)

Standard FashionMNIST train, val, and test splits and transforms.

Transforms::

mnist_transforms = transform_lib.Compose([
    transform_lib.ToTensor()
])

Example::

from pl_bolts.datamodules import FashionMNISTDataModule

dm = FashionMNISTDataModule('.')  # '.': store the data in the current directory
model = LitModel()  # LitModel: your own LightningModule

Trainer().fit(model, datamodule=dm)

ImageNetDataModule #

Bases: VisionDataModule

ImageNet datamodule.

Extracted from https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/datamodules/imagenet_datamodule.py and made into a subclass of VisionDataModule.

Notes:
  • train_dataloader uses the train split of imagenet2012 and sets aside a portion of it for the validation split.
  • val_dataloader uses the part of the train split of imagenet2012 that was not used for training, selected via num_imgs_per_val_class.
      • TODO: needs to pass split='val' to UnlabeledImagenet.
  • test_dataloader uses the validation split of imagenet2012 for testing.
      • TODO: need to pass num_imgs_per_class=-1 for test dataset and split="test".
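
To make the per-class hold-out described in the notes concrete, here is a minimal sketch under stated assumptions. `split_per_class` is a hypothetical helper, and taking the first images of each class is an arbitrary choice for illustration; the datamodule may select the held-out images differently.

```python
from collections import defaultdict


def split_per_class(
    labels: list[int], num_imgs_per_val_class: int
) -> tuple[list[int], list[int]]:
    """Hold out the first `num_imgs_per_val_class` samples of each class.

    Returns (train_indices, val_indices) into `labels`.
    """
    by_class: dict[int, list[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train_indices: list[int] = []
    val_indices: list[int] = []
    for indices in by_class.values():
        val_indices.extend(indices[:num_imgs_per_val_class])
        train_indices.extend(indices[num_imgs_per_val_class:])
    return sorted(train_indices), sorted(val_indices)


# Toy labels: three classes with four images each.
labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
train_idx, val_idx = split_per_class(labels, num_imgs_per_val_class=1)
print(train_idx, val_idx)
```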

name class-attribute #

name: str = 'imagenet'

Dataset name.

dataset_cls class-attribute #

dataset_cls: type[ImageNet] = ImageNet

Dataset class to use.

dims class-attribute instance-attribute #

dims: tuple[C, H, W] = (
    C(3),
    H(image_size),
    W(image_size),
)

A tuple describing the shape of the data.

__init__ #

__init__(
    data_dir: str | Path = DATA_DIR,
    *,
    val_split: int | float = 0.01,
    num_workers: int = NUM_WORKERS,
    normalize: bool = False,
    image_size: int = 224,
    batch_size: int = 32,
    seed: int = 42,
    shuffle: bool = True,
    pin_memory: bool = True,
    drop_last: bool = False,
    train_transforms: Callable | None = None,
    val_transforms: Callable | None = None,
    test_transforms: Callable | None = None,
    **kwargs
)

Creates an ImageNet datamodule (doesn't load or prepare the dataset yet).

Parameters:

Name              Type              Default       Description
data_dir          str | Path        DATA_DIR      Path to the ImageNet dataset file
val_split         int | float       0.01          Save val_split% of the training data of each class for validation
image_size        int               224           Final image size
num_workers       int               NUM_WORKERS   How many workers to use for loading data
batch_size        int               32            How many samples per batch to load
shuffle           bool              True          If true, shuffles the data every epoch
pin_memory        bool              True          If true, the data loader will copy Tensors into CUDA pinned memory before returning them
drop_last         bool              False         If true, drops the last incomplete batch
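
The interaction between `batch_size` and `drop_last` comes down to simple arithmetic. The sketch below counts the batches per epoch for a map-style dataset; `num_batches` is a hypothetical helper written for this illustration.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields per epoch."""
    if drop_last:
        # The trailing partial batch is discarded.
        return num_samples // batch_size
    # The trailing partial batch is kept.
    return math.ceil(num_samples / batch_size)


print(num_batches(1000, 32, drop_last=False))
print(num_batches(1000, 32, drop_last=True))
```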

train_transform #

train_transform() -> Module[[Tensor], Tensor]

The standard imagenet transforms.

.. code-block:: python

transform_lib.Compose([
    transform_lib.RandomResizedCrop(self.image_size),
    transform_lib.RandomHorizontalFlip(),
    transform_lib.ToTensor(),
    transform_lib.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

val_transform #

val_transform() -> Callable

The standard imagenet transforms for validation.

.. code-block:: python

transform_lib.Compose([
    transform_lib.Resize(self.image_size + 32),
    transform_lib.CenterCrop(self.image_size),
    transform_lib.ToTensor(),
    transform_lib.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])
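
The geometry of `Resize(self.image_size + 32)` followed by `CenterCrop(self.image_size)` can be worked out without loading any image. The sketch below is plain Python with integer arithmetic and assumes the shorter side is resized to the target length; `val_crop_geometry` is a hypothetical helper, and its rounding may differ by a pixel from the library implementation.

```python
def val_crop_geometry(width: int, height: int, image_size: int = 224) -> dict[str, int]:
    """Sizes produced by a shorter-side resize followed by a center crop."""
    target = image_size + 32  # the shorter side is resized to this length
    if width <= height:
        resized_w, resized_h = target, target * height // width
    else:
        resized_w, resized_h = target * width // height, target
    return {
        "resized_w": resized_w,
        "resized_h": resized_h,
        # Offsets of the centered crop window inside the resized image.
        "crop_left": (resized_w - image_size) // 2,
        "crop_top": (resized_h - image_size) // 2,
        "crop_size": image_size,
    }


geometry = val_crop_geometry(width=500, height=375)
print(geometry)
```

With the default `image_size` of 224, the shorter side becomes 256 and a 224 x 224 window is cut from the center, which discards a border of roughly 16 pixels along the shorter dimension.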

ImageNet32DataModule #

Bases: VisionDataModule

TODO: Add a val_split argument that supports a value of 0.

prepare_data #

prepare_data() -> None

Saves files to data_dir.

default_transforms #

default_transforms() -> Callable

Default transform for the dataset.

train_dataloader #

train_dataloader() -> DataLoader

The train dataloader.

val_dataloader #

val_dataloader() -> DataLoader

The val dataloader.

test_dataloader #

test_dataloader() -> DataLoader

The test dataloader.

INaturalistDataModule #

Bases: ImageClassificationDataModule

name class-attribute #

name: str = 'inaturalist'

Dataset name.

dataset_cls class-attribute #

dataset_cls: type[INaturalist] = INaturalist

Dataset class to use.

dims class-attribute instance-attribute #

dims: tuple[C, H, W] = (C(3), H(224), W(224))

A tuple describing the shape of the data.

default_transforms #

default_transforms() -> Callable

Default transform for the dataset.

MNISTDataModule #

Bases: ImageClassificationDataModule

.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
    :width: 400
    :alt: MNIST

Specs
  • 10 classes (1 per digit)
  • Each image is (1 x 28 x 28)

Standard MNIST train, val, and test splits and transforms.

Transforms::

mnist_transforms = transform_lib.Compose([
    transform_lib.ToTensor()
])

Example::

from pl_bolts.datamodules import MNISTDataModule

dm = MNISTDataModule('.')  # '.': store the data in the current directory
model = LitModel()  # LitModel: your own LightningModule

Trainer().fit(model, datamodule=dm)

__init__ #

__init__(
    data_dir: str | None = None,
    val_split: int | float = 0.2,
    num_workers: int | None = 0,
    normalize: bool = False,
    batch_size: int = 32,
    seed: int = 42,
    shuffle: bool = True,
    pin_memory: bool = True,
    drop_last: bool = False,
    *args: Any,
    **kwargs: Any
) -> None

Parameters:

Name              Type              Default       Description
data_dir          str | None        None          Where to save/load the data
val_split         int | float       0.2           Fraction (float, e.g. 0.2 for 20%) or number (int) of samples to use for the validation split
num_workers       int | None        0             How many workers to use for loading data
normalize         bool              False         If true, applies image normalization
batch_size        int               32            How many samples per batch to load
seed              int               42            Random seed to be used for train/val/test splits
shuffle           bool              True          If true, shuffles the train data every epoch
pin_memory        bool              True          If true, the data loader will copy Tensors into CUDA pinned memory before returning them
drop_last         bool              False         If true, drops the last incomplete batch
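
The two accepted forms of `val_split` can be sketched as follows. This is an assumption about how such a value is typically interpreted (a float as a fraction of the dataset, an int as an absolute count), not a copy of the datamodule's code, and `split_lengths` is a hypothetical helper.

```python
from __future__ import annotations


def split_lengths(num_samples: int, val_split: int | float) -> tuple[int, int]:
    """Return (train_len, val_len) for a dataset of `num_samples` items."""
    if isinstance(val_split, int):
        val_len = val_split  # absolute number of validation samples
    elif isinstance(val_split, float):
        val_len = int(val_split * num_samples)  # fraction of the dataset
    else:
        raise TypeError(f"Unsupported type {type(val_split)}")
    if not 0 <= val_len <= num_samples:
        raise ValueError("val_split does not fit the dataset size")
    return num_samples - val_len, val_len


# The MNIST training set has 60,000 images.
print(split_lengths(60_000, 0.2))
print(split_lengths(60_000, 5_000))
```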

TextClassificationDataModule #

Bases: LightningDataModule

Lightning data module for HF text classification datasets.

This is based on this tutorial: https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/text-transformers.html

VisionDataModule #

Bases: LightningDataModule, DataModule[BatchType_co]

A LightningDataModule for image datasets.

(Taken from pl_bolts, which is not very well maintained.)

name class-attribute #

name: str = ''

Dataset name.

dataset_cls class-attribute #

dataset_cls: type[VisionDataset]

Dataset class to use.

dims class-attribute #

dims: tuple[C, H, W]

A tuple describing the shape of the data.

__init__ #

__init__(
    data_dir: str | Path = DATA_DIR,
    val_split: int | float = 0.2,
    num_workers: int = NUM_WORKERS,
    normalize: bool = False,
    batch_size: int = 32,
    seed: int = 42,
    shuffle: bool = True,
    pin_memory: bool = True,
    drop_last: bool = False,
    train_transforms: Callable | None = None,
    val_transforms: Callable | None = None,
    test_transforms: Callable | None = None,
    **kwargs
) -> None

Parameters:

Name              Type              Default       Description
data_dir          str | Path        DATA_DIR      Where to save/load the data
val_split         int | float       0.2           Fraction (float, e.g. 0.2 for 20%) or number (int) of samples to use for the validation split
num_workers       int               NUM_WORKERS   How many workers to use for loading data
normalize         bool              False         If true, applies image normalization
batch_size        int               32            How many samples per batch to load
seed              int               42            Random seed to be used for train/val/test splits
shuffle           bool              True          If true, shuffles the train data every epoch
pin_memory        bool              True          If true, the data loader will copy Tensors into CUDA pinned memory before returning them
drop_last         bool              False         If true, drops the last incomplete batch
train_transforms  Callable | None   None          Transformations to apply to the train dataset
val_transforms    Callable | None   None          Transformations to apply to the validation dataset
test_transforms   Callable | None   None          Transformations to apply to the test dataset
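
The role of `seed` is to make the train/val split reproducible across runs. The sketch below shows the idea with Python's `random` module rather than whatever generator the datamodule uses internally; `seeded_split` is a hypothetical helper.

```python
import random


def seeded_split(
    num_samples: int, val_len: int, seed: int = 42
) -> tuple[list[int], list[int]]:
    """Shuffle indices with a private RNG and cut off `val_len` of them."""
    indices = list(range(num_samples))
    # A dedicated Random instance keeps the global RNG state untouched.
    random.Random(seed).shuffle(indices)
    return indices[val_len:], indices[:val_len]


train_a, val_a = seeded_split(10, 2, seed=42)
train_b, val_b = seeded_split(10, 2, seed=42)
print(val_a == val_b)
```

Using the same seed in two processes yields the same validation indices, so metrics computed on the validation split stay comparable between runs.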

prepare_data #

prepare_data() -> None

Saves files to data_dir.

default_transforms abstractmethod #

default_transforms() -> Callable

Default transform for the dataset.

train_dataloader #

train_dataloader(
    _dataloader_fn: Callable[
        Concatenate[Dataset, P], DataLoader
    ] = DataLoader,
    *args: P.args,
    **kwargs: P.kwargs
) -> DataLoader

The train dataloader.

val_dataloader #

val_dataloader(
    _dataloader_fn: Callable[
        Concatenate[Dataset, P], DataLoader
    ] = DataLoader,
    *args: P.args,
    **kwargs: P.kwargs
) -> DataLoader

The val dataloader.

test_dataloader #

test_dataloader(
    _dataloader_fn: Callable[
        Concatenate[Dataset, P], DataLoader
    ] = DataLoader,
    *args: P.args,
    **kwargs: P.kwargs
) -> DataLoader

The test dataloader.